[llvm] [VPlan] Move predication to VPlanTransform (NFC) (WIP). (PR #128420)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sat May 10 04:47:49 PDT 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/128420

>From d0d2c2ed878bff39d8c222a0939aa6c7ad961837 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 28 Apr 2025 19:35:23 +0100
Subject: [PATCH 1/6] [VPlan] Handle early exit before forming regions. (NFC)

Move early-exit handling up front to original VPlan construction, before
introducing regions.

This builds on https://github.com/llvm/llvm-project/pull/137709, which
adds exiting edges to the original VPlan, instead of adding exit blocks
later.

This retains the exit conditions early, and means we can handle early
exits before forming regions, without the reliance on VPRecipeBuilder.

Once we retain all exits initially, handling early exits before region
construction ensures the regions are valid; otherwise we would leave
edges exiting the region from elsewhere than the latch.

Removing the reliance on VPRecipeBuilder removes the dependence on
mapping IR BBs to VPBBs and unblocks predication as a VPlan transform:
https://github.com/llvm/llvm-project/pull/128420.

Depends on https://github.com/llvm/llvm-project/pull/137709.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  12 +-
 .../Vectorize/VPlanConstruction.cpp           |  33 +++---
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 108 +++++++++++-------
 .../Transforms/Vectorize/VPlanTransforms.h    |   8 +-
 .../Transforms/Vectorize/VPlanTestBase.h      |   4 +-
 5 files changed, 94 insertions(+), 71 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 9208fc45a0188..6c5d543ad9fa7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9384,7 +9384,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
   VPlanTransforms::prepareForVectorization(
       *Plan, Legal->getWidestInductionType(), PSE, RequiresScalarEpilogueCheck,
       CM.foldTailByMasking(), OrigLoop,
-      getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()));
+      getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()),
+      Legal->hasUncountableEarlyExit(), Range);
   VPlanTransforms::createLoopRegions(*Plan);
 
   // Don't use getDecisionAndClampRange here, because we don't know the UF
@@ -9582,12 +9583,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
     R->setOperand(1, WideIV->getStepValue());
   }
 
-  if (auto *UncountableExitingBlock =
-          Legal->getUncountableEarlyExitingBlock()) {
-    VPlanTransforms::runPass(VPlanTransforms::handleUncountableEarlyExit, *Plan,
-                             OrigLoop, UncountableExitingBlock, RecipeBuilder,
-                             Range);
-  }
   DenseMap<VPValue *, VPValue *> IVEndValues;
   addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues);
   SetVector<VPIRInstruction *> ExitUsersToFix =
@@ -9685,7 +9680,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) {
   auto Plan = VPlanTransforms::buildPlainCFG(OrigLoop, *LI, VPB2IRBB);
   VPlanTransforms::prepareForVectorization(
       *Plan, Legal->getWidestInductionType(), PSE, true, false, OrigLoop,
-      getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()));
+      getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()), false,
+      Range);
   VPlanTransforms::createLoopRegions(*Plan);
 
   for (ElementCount VF : Range)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index a0edd296caab8..73420b406b8e3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -460,11 +460,10 @@ static void addCanonicalIVRecipes(VPlan &Plan, VPBasicBlock *HeaderVPBB,
                        {CanonicalIVIncrement, &Plan.getVectorTripCount()}, DL);
 }
 
-void VPlanTransforms::prepareForVectorization(VPlan &Plan, Type *InductionTy,
-                                              PredicatedScalarEvolution &PSE,
-                                              bool RequiresScalarEpilogueCheck,
-                                              bool TailFolded, Loop *TheLoop,
-                                              DebugLoc IVDL) {
+void VPlanTransforms::prepareForVectorization(
+    VPlan &Plan, Type *InductionTy, PredicatedScalarEvolution &PSE,
+    bool RequiresScalarEpilogueCheck, bool TailFolded, Loop *TheLoop,
+    DebugLoc IVDL, bool HandleUncountableExit, VFRange &Range) {
   VPDominatorTree VPDT;
   VPDT.recalculate(Plan);
 
@@ -491,16 +490,20 @@ void VPlanTransforms::prepareForVectorization(VPlan &Plan, Type *InductionTy,
   addCanonicalIVRecipes(Plan, cast<VPBasicBlock>(HeaderVPB),
                         cast<VPBasicBlock>(LatchVPB), InductionTy, IVDL);
 
-  // Disconnect all edges to exit blocks other than from the middle block.
-  // TODO: VPlans with early exits should be explicitly converted to a form
-  // exiting only via the latch here, including adjusting the exit condition,
-  // instead of simply disconnecting the edges and adjusting the VPlan later.
-  for (VPBlockBase *EB : Plan.getExitBlocks()) {
-    for (VPBlockBase *Pred : to_vector(EB->getPredecessors())) {
-      if (Pred == MiddleVPBB)
-        continue;
-      cast<VPBasicBlock>(Pred)->getTerminator()->eraseFromParent();
-      VPBlockUtils::disconnectBlocks(Pred, EB);
+  if (HandleUncountableExit) {
+    // Convert VPlans with early exits to a form only exiting via the latch
+    // here, including adjusting the exit condition.
+    handleUncountableEarlyExit(Plan, cast<VPBasicBlock>(HeaderVPB),
+                               cast<VPBasicBlock>(LatchVPB), Range);
+  } else {
+    // Disconnect all edges to exit blocks other than from the middle block.
+    for (VPBlockBase *EB : to_vector(Plan.getExitBlocks())) {
+      for (VPBlockBase *Pred : to_vector(EB->getPredecessors())) {
+        if (Pred == MiddleVPBB)
+          continue;
+        cast<VPBasicBlock>(Pred)->getTerminator()->eraseFromParent();
+        VPBlockUtils::disconnectBlocks(Pred, EB);
+      }
     }
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 79ddb8bf0b09b..3a6c5bc02cdf1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2458,64 +2458,86 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan,
     R->eraseFromParent();
 }
 
-void VPlanTransforms::handleUncountableEarlyExit(
-    VPlan &Plan, Loop *OrigLoop, BasicBlock *UncountableExitingBlock,
-    VPRecipeBuilder &RecipeBuilder, VFRange &Range) {
-  VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
-  auto *LatchVPBB = cast<VPBasicBlock>(LoopRegion->getExiting());
+void VPlanTransforms::handleUncountableEarlyExit(VPlan &Plan,
+                                                 VPBasicBlock *HeaderVPBB,
+                                                 VPBasicBlock *LatchVPBB,
+                                                 VFRange &Range) {
+  // First find the uncountable early exiting block by looking at the
+  // predecessors of the exit blocks.
+  VPBlockBase *MiddleVPBB = LatchVPBB->getSuccessors()[0];
+  VPBasicBlock *EarlyExitingVPBB = nullptr;
+  VPIRBasicBlock *EarlyExitVPBB = nullptr;
+  for (auto *EB : Plan.getExitBlocks()) {
+    for (VPBlockBase *Pred : EB->getPredecessors()) {
+      if (Pred != MiddleVPBB) {
+        EarlyExitingVPBB = cast<VPBasicBlock>(Pred);
+        EarlyExitVPBB = EB;
+        break;
+      }
+    }
+  }
+  assert(EarlyExitVPBB && "Must have a early exiting block.");
+  assert(all_of(Plan.getExitBlocks(),
+                [EarlyExitingVPBB, MiddleVPBB](VPIRBasicBlock *EB) {
+                  return all_of(
+                      EB->getPredecessors(),
+                      [EarlyExitingVPBB, MiddleVPBB](VPBlockBase *Pred) {
+                        return Pred == EarlyExitingVPBB || Pred == MiddleVPBB;
+                      });
+                }) &&
+         "All exit blocks must only have EarlyExitingVPBB or MiddleVPBB as "
+         "predecessors.");
+
   VPBuilder Builder(LatchVPBB->getTerminator());
-  auto *MiddleVPBB = Plan.getMiddleBlock();
-  VPValue *IsEarlyExitTaken = nullptr;
-
-  // Process the uncountable exiting block. Update IsEarlyExitTaken, which
-  // tracks if the uncountable early exit has been taken. Also split the middle
-  // block and have it conditionally branch to the early exit block if
-  // EarlyExitTaken.
-  auto *EarlyExitingBranch =
-      cast<BranchInst>(UncountableExitingBlock->getTerminator());
-  BasicBlock *TrueSucc = EarlyExitingBranch->getSuccessor(0);
-  BasicBlock *FalseSucc = EarlyExitingBranch->getSuccessor(1);
-  BasicBlock *EarlyExitIRBB =
-      !OrigLoop->contains(TrueSucc) ? TrueSucc : FalseSucc;
-  VPIRBasicBlock *VPEarlyExitBlock = Plan.getExitBlock(EarlyExitIRBB);
-
-  VPValue *EarlyExitNotTakenCond = RecipeBuilder.getBlockInMask(
-      OrigLoop->contains(TrueSucc) ? TrueSucc : FalseSucc);
-  auto *EarlyExitTakenCond = Builder.createNot(EarlyExitNotTakenCond);
-  IsEarlyExitTaken =
-      Builder.createNaryOp(VPInstruction::AnyOf, {EarlyExitTakenCond});
+  VPBlockBase *TrueSucc = EarlyExitingVPBB->getSuccessors()[0];
+  VPValue *EarlyExitCond = EarlyExitingVPBB->getTerminator()->getOperand(0);
+  auto *EarlyExitTakenCond = TrueSucc == EarlyExitVPBB
+                                 ? EarlyExitCond
+                                 : Builder.createNot(EarlyExitCond);
+
+  if (!EarlyExitVPBB->getSinglePredecessor() &&
+      EarlyExitVPBB->getPredecessors()[0] != MiddleVPBB) {
+    for (VPRecipeBase &R : EarlyExitVPBB->phis()) {
+      // Early exit operand should always be last, i.e., 0 if EarlyExitVPBB has
+      // a single predecessor and 1 if it has two.
+      // If EarlyExitVPBB has two predecessors, they are already ordered such
+      // that early exit is second (and latch exit is first), by construction.
+      // But its underlying IRBB (EarlyExitIRBB) may have its predecessors
+      // ordered the other way around, and it is the order of the latter which
+      // corresponds to the order of operands of EarlyExitVPBB's phi recipes.
+      // Therefore, if early exit (UncountableExitingBlock) is the first
+      // predecessor of EarlyExitIRBB, we swap the operands of phi recipes,
+      // thereby bringing them to match EarlyExitVPBB's predecessor order,
+      // with early exit being last (second). Otherwise they already match.
+      cast<VPIRPhi>(&R)->swapOperands();
+    }
+  }
 
+  EarlyExitingVPBB->getTerminator()->eraseFromParent();
+  VPBlockUtils::disconnectBlocks(EarlyExitingVPBB, EarlyExitVPBB);
+
+  // Split the middle block and have it conditionally branch to the early exit
+  // block if EarlyExitTaken.
+  VPValue *IsEarlyExitTaken =
+      Builder.createNaryOp(VPInstruction::AnyOf, {EarlyExitTakenCond});
   VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split");
   VPBasicBlock *VectorEarlyExitVPBB =
       Plan.createVPBasicBlock("vector.early.exit");
-  VPBlockUtils::insertOnEdge(LoopRegion, MiddleVPBB, NewMiddle);
+  VPBlockUtils::insertOnEdge(LatchVPBB, MiddleVPBB, NewMiddle);
   VPBlockUtils::connectBlocks(NewMiddle, VectorEarlyExitVPBB);
   NewMiddle->swapSuccessors();
 
-  VPBlockUtils::connectBlocks(VectorEarlyExitVPBB, VPEarlyExitBlock);
+  VPBlockUtils::connectBlocks(VectorEarlyExitVPBB, EarlyExitVPBB);
 
   // Update the exit phis in the early exit block.
   VPBuilder MiddleBuilder(NewMiddle);
   VPBuilder EarlyExitB(VectorEarlyExitVPBB);
-  for (VPRecipeBase &R : VPEarlyExitBlock->phis()) {
+  for (VPRecipeBase &R : EarlyExitVPBB->phis()) {
     auto *ExitIRI = cast<VPIRPhi>(&R);
-    // Early exit operand should always be last, i.e., 0 if VPEarlyExitBlock has
+    // Early exit operand should always be last, i.e., 0 if EarlyExitVPBB has
     // a single predecessor and 1 if it has two.
     unsigned EarlyExitIdx = ExitIRI->getNumOperands() - 1;
-    if (!VPEarlyExitBlock->getSinglePredecessor()) {
-      // If VPEarlyExitBlock has two predecessors, they are already ordered such
-      // that early exit is second (and latch exit is first), by construction.
-      // But its underlying IRBB (EarlyExitIRBB) may have its predecessors
-      // ordered the other way around, and it is the order of the latter which
-      // corresponds to the order of operands of VPEarlyExitBlock's phi recipes.
-      // Therefore, if early exit (UncountableExitingBlock) is the first
-      // predecessor of EarlyExitIRBB, we swap the operands of phi recipes,
-      // thereby bringing them to match VPEarlyExitBlock's predecessor order,
-      // with early exit being last (second). Otherwise they already match.
-      if (*pred_begin(VPEarlyExitBlock->getIRBasicBlock()) ==
-          UncountableExitingBlock)
-        ExitIRI->swapOperands();
-
+    if (!EarlyExitVPBB->getSinglePredecessor()) {
       // The first of two operands corresponds to the latch exit, via MiddleVPBB
       // predecessor. Extract its last lane.
       ExitIRI->extractLastLaneOfFirstOperand(MiddleBuilder);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 7a05816f2e2da..adb984fc56bac 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -69,7 +69,8 @@ struct VPlanTransforms {
                                       PredicatedScalarEvolution &PSE,
                                       bool RequiresScalarEpilogueCheck,
                                       bool TailFolded, Loop *TheLoop,
-                                      DebugLoc IVDL);
+                                      DebugLoc IVDL, bool HandleUncountableExit,
+                                      VFRange &Range);
 
   /// Replace loops in \p Plan's flat CFG with VPRegionBlocks, turning \p Plan's
   /// flat CFG into a hierarchical CFG.
@@ -179,9 +180,8 @@ struct VPlanTransforms {
   ///    exit conditions
   ///  * splitting the original middle block to branch to the early exit block
   ///    if taken.
-  static void handleUncountableEarlyExit(VPlan &Plan, Loop *OrigLoop,
-                                         BasicBlock *UncountableExitingBlock,
-                                         VPRecipeBuilder &RecipeBuilder,
+  static void handleUncountableEarlyExit(VPlan &Plan, VPBasicBlock *HeaderVPBB,
+                                         VPBasicBlock *LatchVPBB,
                                          VFRange &Range);
 
   /// Lower abstract recipes to concrete ones, that can be codegen'd. Use \p
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
index bf67a5596b270..15e21972840f6 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
@@ -13,6 +13,7 @@
 #define LLVM_UNITTESTS_TRANSFORMS_VECTORIZE_VPLANTESTBASE_H
 
 #include "../lib/Transforms/Vectorize/VPlan.h"
+#include "../lib/Transforms/Vectorize/VPlanHelpers.h"
 #include "../lib/Transforms/Vectorize/VPlanTransforms.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -72,8 +73,9 @@ class VPlanTestIRBase : public testing::Test {
     PredicatedScalarEvolution PSE(*SE, *L);
     DenseMap<VPBlockBase *, BasicBlock *> VPB2IRBB;
     auto Plan = VPlanTransforms::buildPlainCFG(L, *LI, VPB2IRBB);
+    VFRange R(ElementCount::getFixed(1), ElementCount::getFixed(2));
     VPlanTransforms::prepareForVectorization(*Plan, IntegerType::get(*Ctx, 64),
-                                             PSE, true, false, L, {});
+                                             PSE, true, false, L, {}, false, R);
     VPlanTransforms::createLoopRegions(*Plan);
     return Plan;
   }

>From 76c470a914b38e32f7c40b234c23083911db68cc Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 8 May 2025 20:11:33 +0100
Subject: [PATCH 2/6] !fixup address comments, thanks!

---
 .../Vectorize/VPlanConstruction.cpp           | 40 +++++++----
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 72 ++++++-------------
 .../Transforms/Vectorize/VPlanTransforms.h    |  8 ++-
 3 files changed, 54 insertions(+), 66 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index 73420b406b8e3..4270564fccec0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -463,7 +463,7 @@ static void addCanonicalIVRecipes(VPlan &Plan, VPBasicBlock *HeaderVPBB,
 void VPlanTransforms::prepareForVectorization(
     VPlan &Plan, Type *InductionTy, PredicatedScalarEvolution &PSE,
     bool RequiresScalarEpilogueCheck, bool TailFolded, Loop *TheLoop,
-    DebugLoc IVDL, bool HandleUncountableExit, VFRange &Range) {
+    DebugLoc IVDL, bool HasUncountableEarlyExit, VFRange &Range) {
   VPDominatorTree VPDT;
   VPDT.recalculate(Plan);
 
@@ -490,23 +490,35 @@ void VPlanTransforms::prepareForVectorization(
   addCanonicalIVRecipes(Plan, cast<VPBasicBlock>(HeaderVPB),
                         cast<VPBasicBlock>(LatchVPB), InductionTy, IVDL);
 
-  if (HandleUncountableExit) {
-    // Convert VPlans with early exits to a form only exiting via the latch
-    // here, including adjusting the exit condition.
-    handleUncountableEarlyExit(Plan, cast<VPBasicBlock>(HeaderVPB),
-                               cast<VPBasicBlock>(LatchVPB), Range);
-  } else {
-    // Disconnect all edges to exit blocks other than from the middle block.
-    for (VPBlockBase *EB : to_vector(Plan.getExitBlocks())) {
-      for (VPBlockBase *Pred : to_vector(EB->getPredecessors())) {
-        if (Pred == MiddleVPBB)
-          continue;
-        cast<VPBasicBlock>(Pred)->getTerminator()->eraseFromParent();
-        VPBlockUtils::disconnectBlocks(Pred, EB);
+  [[maybe_unused]] bool HandledUncountableEarlyExit = false;
+  for (VPIRBasicBlock *EB : Plan.getExitBlocks()) {
+    for (VPBlockBase *Pred : to_vector(EB->getPredecessors())) {
+      if (Pred == MiddleVPBB)
+        continue;
+
+      if (HasUncountableEarlyExit) {
+        assert(!HandledUncountableEarlyExit &&
+               "can handle exactly one uncountable early exit");
+        // Convert VPlans with early exits to a form exiting only via the latch
+        // here, including adjusting the exit condition of the latch.
+        handleUncountableEarlyExit(cast<VPBasicBlock>(Pred), EB, Plan,
+                                   cast<VPBasicBlock>(HeaderVPB),
+                                   cast<VPBasicBlock>(LatchVPB), Range);
+        HandledUncountableEarlyExit = true;
+        continue;
       }
+
+      // Otherwise all early exits must be countable and we require at least one
+      // iteration in the scalar epilogue. Disconnect all edges to exit blocks
+      // other than from the middle block.
+      cast<VPBasicBlock>(Pred)->getTerminator()->eraseFromParent();
+      VPBlockUtils::disconnectBlocks(Pred, EB);
     }
   }
 
+  assert((!HasUncountableEarlyExit || HandledUncountableEarlyExit) &&
+         "did not handle uncountable early exit");
+
   // Create SCEV and VPValue for the trip count.
   // We use the symbolic max backedge-taken-count, which works also when
   // vectorizing loops with uncountable early exits.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 3a6c5bc02cdf1..92ff1fd05fc64 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2458,68 +2458,42 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan,
     R->eraseFromParent();
 }
 
-void VPlanTransforms::handleUncountableEarlyExit(VPlan &Plan,
+void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
+                                                 VPBasicBlock *EarlyExitVPBB,
+
+                                                 VPlan &Plan,
                                                  VPBasicBlock *HeaderVPBB,
                                                  VPBasicBlock *LatchVPBB,
                                                  VFRange &Range) {
-  // First find the uncountable early exiting block by looking at the
-  // predecessors of the exit blocks.
-  VPBlockBase *MiddleVPBB = LatchVPBB->getSuccessors()[0];
-  VPBasicBlock *EarlyExitingVPBB = nullptr;
-  VPIRBasicBlock *EarlyExitVPBB = nullptr;
-  for (auto *EB : Plan.getExitBlocks()) {
-    for (VPBlockBase *Pred : EB->getPredecessors()) {
-      if (Pred != MiddleVPBB) {
-        EarlyExitingVPBB = cast<VPBasicBlock>(Pred);
-        EarlyExitVPBB = EB;
-        break;
-      }
-    }
-  }
-  assert(EarlyExitVPBB && "Must have a early exiting block.");
-  assert(all_of(Plan.getExitBlocks(),
-                [EarlyExitingVPBB, MiddleVPBB](VPIRBasicBlock *EB) {
-                  return all_of(
-                      EB->getPredecessors(),
-                      [EarlyExitingVPBB, MiddleVPBB](VPBlockBase *Pred) {
-                        return Pred == EarlyExitingVPBB || Pred == MiddleVPBB;
-                      });
-                }) &&
-         "All exit blocks must only have EarlyExitingVPBB or MiddleVPBB as "
-         "predecessors.");
-
-  VPBuilder Builder(LatchVPBB->getTerminator());
-  VPBlockBase *TrueSucc = EarlyExitingVPBB->getSuccessors()[0];
-  VPValue *EarlyExitCond = EarlyExitingVPBB->getTerminator()->getOperand(0);
-  auto *EarlyExitTakenCond = TrueSucc == EarlyExitVPBB
-                                 ? EarlyExitCond
-                                 : Builder.createNot(EarlyExitCond);
+  using namespace llvm::VPlanPatternMatch;
 
+  VPBlockBase *MiddleVPBB = LatchVPBB->getSuccessors()[0];
   if (!EarlyExitVPBB->getSinglePredecessor() &&
       EarlyExitVPBB->getPredecessors()[0] != MiddleVPBB) {
-    for (VPRecipeBase &R : EarlyExitVPBB->phis()) {
-      // Early exit operand should always be last, i.e., 0 if EarlyExitVPBB has
-      // a single predecessor and 1 if it has two.
-      // If EarlyExitVPBB has two predecessors, they are already ordered such
-      // that early exit is second (and latch exit is first), by construction.
-      // But its underlying IRBB (EarlyExitIRBB) may have its predecessors
-      // ordered the other way around, and it is the order of the latter which
-      // corresponds to the order of operands of EarlyExitVPBB's phi recipes.
-      // Therefore, if early exit (UncountableExitingBlock) is the first
-      // predecessor of EarlyExitIRBB, we swap the operands of phi recipes,
-      // thereby bringing them to match EarlyExitVPBB's predecessor order,
-      // with early exit being last (second). Otherwise they already match.
+    // Early exit operand should always be last phi operand. If EarlyExitVPBB
+    // has two predecessors and MiddleVPBB isn't the first, swap the operands of
+    // the phis.
+    for (VPRecipeBase &R : EarlyExitVPBB->phis())
       cast<VPIRPhi>(&R)->swapOperands();
-    }
   }
 
+  VPBuilder Builder(LatchVPBB->getTerminator());
+  VPBlockBase *TrueSucc = EarlyExitingVPBB->getSuccessors()[0];
+  assert(
+      match(EarlyExitingVPBB->getTerminator(), m_BranchOnCond(m_VPValue())) &&
+      "Terminator must be be BranchOnCond");
+  VPValue *CondOfEarlyExitingVPBB =
+      EarlyExitingVPBB->getTerminator()->getOperand(0);
+  auto *CondToEarlyExit = TrueSucc == EarlyExitVPBB
+                              ? CondOfEarlyExitingVPBB
+                              : Builder.createNot(CondOfEarlyExitingVPBB);
   EarlyExitingVPBB->getTerminator()->eraseFromParent();
   VPBlockUtils::disconnectBlocks(EarlyExitingVPBB, EarlyExitVPBB);
 
   // Split the middle block and have it conditionally branch to the early exit
   // block if EarlyExitTaken.
   VPValue *IsEarlyExitTaken =
-      Builder.createNaryOp(VPInstruction::AnyOf, {EarlyExitTakenCond});
+      Builder.createNaryOp(VPInstruction::AnyOf, {CondToEarlyExit});
   VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split");
   VPBasicBlock *VectorEarlyExitVPBB =
       Plan.createVPBasicBlock("vector.early.exit");
@@ -2537,7 +2511,7 @@ void VPlanTransforms::handleUncountableEarlyExit(VPlan &Plan,
     // Early exit operand should always be last, i.e., 0 if EarlyExitVPBB has
     // a single predecessor and 1 if it has two.
     unsigned EarlyExitIdx = ExitIRI->getNumOperands() - 1;
-    if (!EarlyExitVPBB->getSinglePredecessor()) {
+    if (ExitIRI->getNumOperands() != 1) {
       // The first of two operands corresponds to the latch exit, via MiddleVPBB
       // predecessor. Extract its last lane.
       ExitIRI->extractLastLaneOfFirstOperand(MiddleBuilder);
@@ -2553,7 +2527,7 @@ void VPlanTransforms::handleUncountableEarlyExit(VPlan &Plan,
         LoopVectorizationPlanner::getDecisionAndClampRange(IsVector, Range)) {
       // Update the incoming value from the early exit.
       VPValue *FirstActiveLane = EarlyExitB.createNaryOp(
-          VPInstruction::FirstActiveLane, {EarlyExitTakenCond}, nullptr,
+          VPInstruction::FirstActiveLane, {CondToEarlyExit}, nullptr,
           "first.active.lane");
       IncomingFromEarlyExit = EarlyExitB.createNaryOp(
           Instruction::ExtractElement, {IncomingFromEarlyExit, FirstActiveLane},
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index adb984fc56bac..2992bc56d8ac8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -174,13 +174,15 @@ struct VPlanTransforms {
   /// Remove dead recipes from \p Plan.
   static void removeDeadRecipes(VPlan &Plan);
 
-  /// Update \p Plan to account for the uncountable early exit block in \p
-  /// UncountableExitingBlock by
+  /// Update \p Plan to account for the uncountable early exit from \p
+  /// EarlyExitingVPBB to \p EarlyExitVPBB by
   ///  * updating the condition exiting the vector loop to include the early
   ///    exit conditions
   ///  * splitting the original middle block to branch to the early exit block
   ///    if taken.
-  static void handleUncountableEarlyExit(VPlan &Plan, VPBasicBlock *HeaderVPBB,
+  static void handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
+                                         VPBasicBlock *EarlyExitVPBB,
+                                         VPlan &Plan, VPBasicBlock *HeaderVPBB,
                                          VPBasicBlock *LatchVPBB,
                                          VFRange &Range);
 

>From 56d576a99a564b4a1c0dfe0d61d2f19d1d24888e Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 9 May 2025 16:16:40 +0100
Subject: [PATCH 3/6] !fixup address comments, thanks

---
 llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp | 3 +++
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp   | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index 4270564fccec0..3a2dc791b024d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -491,6 +491,9 @@ void VPlanTransforms::prepareForVectorization(
                         cast<VPBasicBlock>(LatchVPB), InductionTy, IVDL);
 
   [[maybe_unused]] bool HandledUncountableEarlyExit = false;
+  // Handle the remaining early exits, either by converting the plan to one only
+  // exiting via the latch or by disconnecting all early exiting edges and
+  // requiring a scalar epilogue.
   for (VPIRBasicBlock *EB : Plan.getExitBlocks()) {
     for (VPBlockBase *Pred : to_vector(EB->getPredecessors())) {
       if (Pred == MiddleVPBB)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 20041fb1194b1..7ec4faee08f62 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2472,6 +2472,9 @@ void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
   VPBlockBase *MiddleVPBB = LatchVPBB->getSuccessors()[0];
   if (!EarlyExitVPBB->getSinglePredecessor() &&
       EarlyExitVPBB->getPredecessors()[0] != MiddleVPBB) {
+    assert(EarlyExitVPBB->getNumPredecessors() == 2 &&
+           EarlyExitVPBB->getPredecessors()[1] == MiddleVPBB &&
+           "unsupported earl exit VPBB");
     // Early exit operand should always be last phi operand. If EarlyExitVPBB
     // has two predecessors and MiddleVPBB isn't the first, swap the operands of
     // the phis.

>From 2289a5e0d01b26c0818000001604ca316efe3e3e Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 9 May 2025 19:49:17 +0100
Subject: [PATCH 4/6] !fixup address latest comments, thanks

---
 .../Transforms/Vectorize/VPlanConstruction.cpp  | 17 ++++++-----------
 .../Transforms/Vectorize/VPlanTransforms.cpp    | 15 ++++++---------
 llvm/lib/Transforms/Vectorize/VPlanTransforms.h |  6 +++---
 3 files changed, 15 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index 3a2dc791b024d..b924b14035261 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -491,36 +491,31 @@ void VPlanTransforms::prepareForVectorization(
                         cast<VPBasicBlock>(LatchVPB), InductionTy, IVDL);
 
   [[maybe_unused]] bool HandledUncountableEarlyExit = false;
-  // Handle the remaining early exits, either by converting the plan to one only
-  // exiting via the latch or by disconnecting all early exiting edges and
-  // requiring a scalar epilogue.
+  // Disconnect all early exits from the loop leaving it with a single exit from
+  // the latch. Early exits that are countable are left for a scalar epilog. The
+  // condition of uncountable early exits (currently at most one is supported)
+  // is fused into the latch exit, and used to branch from middle block to the
+  // early exit destination.
   for (VPIRBasicBlock *EB : Plan.getExitBlocks()) {
     for (VPBlockBase *Pred : to_vector(EB->getPredecessors())) {
       if (Pred == MiddleVPBB)
         continue;
-
       if (HasUncountableEarlyExit) {
         assert(!HandledUncountableEarlyExit &&
                "can handle exactly one uncountable early exit");
-        // Convert VPlans with early exits to a form exiting only via the latch
-        // here, including adjusting the exit condition of the latch.
         handleUncountableEarlyExit(cast<VPBasicBlock>(Pred), EB, Plan,
                                    cast<VPBasicBlock>(HeaderVPB),
                                    cast<VPBasicBlock>(LatchVPB), Range);
         HandledUncountableEarlyExit = true;
-        continue;
       }
 
-      // Otherwise all early exits must be countable and we require at least one
-      // iteration in the scalar epilogue. Disconnect all edges to exit blocks
-      // other than from the middle block.
       cast<VPBasicBlock>(Pred)->getTerminator()->eraseFromParent();
       VPBlockUtils::disconnectBlocks(Pred, EB);
     }
   }
 
   assert((!HasUncountableEarlyExit || HandledUncountableEarlyExit) &&
-         "did not handle uncountable early exit");
+         "missed an uncountable exit that must be handled");
 
   // Create SCEV and VPValue for the trip count.
   // We use the symbolic max backedge-taken-count, which works also when
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 7ec4faee08f62..c195efe86c806 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2462,7 +2462,6 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan,
 
 void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
                                                  VPBasicBlock *EarlyExitVPBB,
-
                                                  VPlan &Plan,
                                                  VPBasicBlock *HeaderVPBB,
                                                  VPBasicBlock *LatchVPBB,
@@ -2471,13 +2470,13 @@ void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
 
   VPBlockBase *MiddleVPBB = LatchVPBB->getSuccessors()[0];
   if (!EarlyExitVPBB->getSinglePredecessor() &&
-      EarlyExitVPBB->getPredecessors()[0] != MiddleVPBB) {
+      EarlyExitVPBB->getPredecessors()[1] == MiddleVPBB) {
     assert(EarlyExitVPBB->getNumPredecessors() == 2 &&
-           EarlyExitVPBB->getPredecessors()[1] == MiddleVPBB &&
-           "unsupported earl exit VPBB");
+           EarlyExitVPBB->getPredecessors()[0] == EarlyExitingVPBB &&
+           "unsupported early exit VPBB");
     // Early exit operand should always be last phi operand. If EarlyExitVPBB
-    // has two predecessors and MiddleVPBB isn't the first, swap the operands of
-    // the phis.
+    // has two predecessors and EarlyExitingVPBB is the first, swap the operands
+    // of the phis.
     for (VPRecipeBase &R : EarlyExitVPBB->phis())
       cast<VPIRPhi>(&R)->swapOperands();
   }
@@ -2492,11 +2491,9 @@ void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
   auto *CondToEarlyExit = TrueSucc == EarlyExitVPBB
                               ? CondOfEarlyExitingVPBB
                               : Builder.createNot(CondOfEarlyExitingVPBB);
-  EarlyExitingVPBB->getTerminator()->eraseFromParent();
-  VPBlockUtils::disconnectBlocks(EarlyExitingVPBB, EarlyExitVPBB);
 
   // Split the middle block and have it conditionally branch to the early exit
-  // block if EarlyExitTaken.
+  // block if CondToEarlyExit.
   VPValue *IsEarlyExitTaken =
       Builder.createNaryOp(VPInstruction::AnyOf, {CondToEarlyExit});
   VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 2992bc56d8ac8..530e06d983e23 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -69,7 +69,7 @@ struct VPlanTransforms {
                                       PredicatedScalarEvolution &PSE,
                                       bool RequiresScalarEpilogueCheck,
                                       bool TailFolded, Loop *TheLoop,
-                                      DebugLoc IVDL, bool HandleUncountableExit,
+                                      DebugLoc IVDL, bool HasUncountableExit,
                                       VFRange &Range);
 
   /// Replace loops in \p Plan's flat CFG with VPRegionBlocks, turning \p Plan's
@@ -177,9 +177,9 @@ struct VPlanTransforms {
   /// Update \p Plan to account for the uncountable early exit from \p
   /// EarlyExitingVPBB to \p EarlyExitVPBB by
   ///  * updating the condition exiting the vector loop to include the early
-  ///    exit conditions
+  ///    exit condition,
   ///  * splitting the original middle block to branch to the early exit block
-  ///    if taken.
+  ///    conditionally - according to the early exit condition.
   static void handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
                                          VPBasicBlock *EarlyExitVPBB,
                                          VPlan &Plan, VPBasicBlock *HeaderVPBB,

>From 474767852ae56842733123a1d487ff6fc2077d5a Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 9 May 2025 20:05:52 +0100
Subject: [PATCH 5/6] !fixup  fix formatting

---
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index c195efe86c806..806c20ef8cf73 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2460,12 +2460,9 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan,
     R->eraseFromParent();
 }
 
-void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
-                                                 VPBasicBlock *EarlyExitVPBB,
-                                                 VPlan &Plan,
-                                                 VPBasicBlock *HeaderVPBB,
-                                                 VPBasicBlock *LatchVPBB,
-                                                 VFRange &Range) {
+void VPlanTransforms::handleUncountableEarlyExit(
+    VPBasicBlock *EarlyExitingVPBB, VPBasicBlock *EarlyExitVPBB, VPlan &Plan,
+    VPBasicBlock *HeaderVPBB, VPBasicBlock *LatchVPBB, VFRange &Range) {
   using namespace llvm::VPlanPatternMatch;
 
   VPBlockBase *MiddleVPBB = LatchVPBB->getSuccessors()[0];

>From 4129042d604f8aa59d0021c0c789b9009449de52 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sat, 22 Feb 2025 19:15:32 +0000
Subject: [PATCH 6/6] [VPlan] Move predication to VPlanTransform (NFC) (WIP).

This patch moves the logic to predicate and linearize a VPlan to a
dedicated VPlan transform.

The main logic to perform predication is ready to review, although
there are a few things to note that should be improved, either directly
in the PR or in the future:
 * Edge and block masks are cached in VPRecipeBuilder, so they can be
   accessed during recipe construction. A better alternative may be to
   add mask operands to all VPInstructions that need them and use those
   during recipe construction.
 * The mask caching in a map also means that this map needs updating
   each time a new recipe replaces a VPInstruction; this would also be
   handled by adding mask operands.

Currently this is still WIP because early-exit loop handling does not
work yet: the exit conditions are not available in the initial VPlans.
This will be fixed by https://github.com/llvm/llvm-project/pull/128419
and follow-ups.

All tests except those for early-exit loops are passing.
---
 llvm/lib/Transforms/Vectorize/CMakeLists.txt  |   1 +
 .../Transforms/Vectorize/LoopVectorize.cpp    | 314 +++---------------
 .../Transforms/Vectorize/VPRecipeBuilder.h    |  49 +--
 .../Vectorize/VPlanConstruction.cpp           |  16 +-
 .../Transforms/Vectorize/VPlanPredicator.cpp  | 310 +++++++++++++++++
 .../Transforms/Vectorize/VPlanTransforms.h    |  14 +-
 .../Transforms/Vectorize/VPlanTestBase.h      |   3 +-
 7 files changed, 379 insertions(+), 328 deletions(-)
 create mode 100644 llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp

diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index 0dc6a7d2f594f..e6c7142edd100 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -24,6 +24,7 @@ add_llvm_component_library(LLVMVectorize
   VPlan.cpp
   VPlanAnalysis.cpp
   VPlanConstruction.cpp
+  VPlanPredicator.cpp
   VPlanRecipes.cpp
   VPlanSLP.cpp
   VPlanTransforms.cpp
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 93eddd87d17bf..b6e0dcb4b9930 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8215,185 +8215,6 @@ void EpilogueVectorizerEpilogueLoop::printDebugTracesAtEnd() {
   });
 }
 
-void VPRecipeBuilder::createSwitchEdgeMasks(SwitchInst *SI) {
-  BasicBlock *Src = SI->getParent();
-  assert(!OrigLoop->isLoopExiting(Src) &&
-         all_of(successors(Src),
-                [this](BasicBlock *Succ) {
-                  return OrigLoop->getHeader() != Succ;
-                }) &&
-         "unsupported switch either exiting loop or continuing to header");
-  // Create masks where the terminator in Src is a switch. We create mask for
-  // all edges at the same time. This is more efficient, as we can create and
-  // collect compares for all cases once.
-  VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition());
-  BasicBlock *DefaultDst = SI->getDefaultDest();
-  MapVector<BasicBlock *, SmallVector<VPValue *>> Dst2Compares;
-  for (auto &C : SI->cases()) {
-    BasicBlock *Dst = C.getCaseSuccessor();
-    assert(!EdgeMaskCache.contains({Src, Dst}) && "Edge masks already created");
-    // Cases whose destination is the same as default are redundant and can be
-    // ignored - they will get there anyhow.
-    if (Dst == DefaultDst)
-      continue;
-    auto &Compares = Dst2Compares[Dst];
-    VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue());
-    Compares.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
-  }
-
-  // We need to handle 2 separate cases below for all entries in Dst2Compares,
-  // which excludes destinations matching the default destination.
-  VPValue *SrcMask = getBlockInMask(Src);
-  VPValue *DefaultMask = nullptr;
-  for (const auto &[Dst, Conds] : Dst2Compares) {
-    // 1. Dst is not the default destination. Dst is reached if any of the cases
-    // with destination == Dst are taken. Join the conditions for each case
-    // whose destination == Dst using an OR.
-    VPValue *Mask = Conds[0];
-    for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
-      Mask = Builder.createOr(Mask, V);
-    if (SrcMask)
-      Mask = Builder.createLogicalAnd(SrcMask, Mask);
-    EdgeMaskCache[{Src, Dst}] = Mask;
-
-    // 2. Create the mask for the default destination, which is reached if none
-    // of the cases with destination != default destination are taken. Join the
-    // conditions for each case where the destination is != Dst using an OR and
-    // negate it.
-    DefaultMask = DefaultMask ? Builder.createOr(DefaultMask, Mask) : Mask;
-  }
-
-  if (DefaultMask) {
-    DefaultMask = Builder.createNot(DefaultMask);
-    if (SrcMask)
-      DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
-  }
-  EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
-}
-
-VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
-  assert(is_contained(predecessors(Dst), Src) && "Invalid edge");
-
-  // Look for cached value.
-  std::pair<BasicBlock *, BasicBlock *> Edge(Src, Dst);
-  EdgeMaskCacheTy::iterator ECEntryIt = EdgeMaskCache.find(Edge);
-  if (ECEntryIt != EdgeMaskCache.end())
-    return ECEntryIt->second;
-
-  if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator())) {
-    createSwitchEdgeMasks(SI);
-    assert(EdgeMaskCache.contains(Edge) && "Mask for Edge not created?");
-    return EdgeMaskCache[Edge];
-  }
-
-  VPValue *SrcMask = getBlockInMask(Src);
-
-  // The terminator has to be a branch inst!
-  BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator());
-  assert(BI && "Unexpected terminator found");
-  if (!BI->isConditional() || BI->getSuccessor(0) == BI->getSuccessor(1))
-    return EdgeMaskCache[Edge] = SrcMask;
-
-  // If source is an exiting block, we know the exit edge is dynamically dead
-  // in the vector loop, and thus we don't need to restrict the mask.  Avoid
-  // adding uses of an otherwise potentially dead instruction unless we are
-  // vectorizing a loop with uncountable exits. In that case, we always
-  // materialize the mask.
-  if (OrigLoop->isLoopExiting(Src) &&
-      Src != Legal->getUncountableEarlyExitingBlock())
-    return EdgeMaskCache[Edge] = SrcMask;
-
-  VPValue *EdgeMask = getVPValueOrAddLiveIn(BI->getCondition());
-  assert(EdgeMask && "No Edge Mask found for condition");
-
-  if (BI->getSuccessor(0) != Dst)
-    EdgeMask = Builder.createNot(EdgeMask, BI->getDebugLoc());
-
-  if (SrcMask) { // Otherwise block in-mask is all-one, no need to AND.
-    // The bitwise 'And' of SrcMask and EdgeMask introduces new UB if SrcMask
-    // is false and EdgeMask is poison. Avoid that by using 'LogicalAnd'
-    // instead which generates 'select i1 SrcMask, i1 EdgeMask, i1 false'.
-    EdgeMask = Builder.createLogicalAnd(SrcMask, EdgeMask, BI->getDebugLoc());
-  }
-
-  return EdgeMaskCache[Edge] = EdgeMask;
-}
-
-VPValue *VPRecipeBuilder::getEdgeMask(BasicBlock *Src, BasicBlock *Dst) const {
-  assert(is_contained(predecessors(Dst), Src) && "Invalid edge");
-
-  // Look for cached value.
-  std::pair<BasicBlock *, BasicBlock *> Edge(Src, Dst);
-  EdgeMaskCacheTy::const_iterator ECEntryIt = EdgeMaskCache.find(Edge);
-  assert(ECEntryIt != EdgeMaskCache.end() &&
-         "looking up mask for edge which has not been created");
-  return ECEntryIt->second;
-}
-
-void VPRecipeBuilder::createHeaderMask() {
-  BasicBlock *Header = OrigLoop->getHeader();
-
-  // When not folding the tail, use nullptr to model all-true mask.
-  if (!CM.foldTailByMasking()) {
-    BlockMaskCache[Header] = nullptr;
-    return;
-  }
-
-  // Introduce the early-exit compare IV <= BTC to form header block mask.
-  // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by
-  // constructing the desired canonical IV in the header block as its first
-  // non-phi instructions.
-
-  VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
-  auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi();
-  auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV());
-  HeaderVPBB->insert(IV, NewInsertionPoint);
-
-  VPBuilder::InsertPointGuard Guard(Builder);
-  Builder.setInsertPoint(HeaderVPBB, NewInsertionPoint);
-  VPValue *BlockMask = nullptr;
-  VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
-  BlockMask = Builder.createICmp(CmpInst::ICMP_ULE, IV, BTC);
-  BlockMaskCache[Header] = BlockMask;
-}
-
-VPValue *VPRecipeBuilder::getBlockInMask(BasicBlock *BB) const {
-  // Return the cached value.
-  BlockMaskCacheTy::const_iterator BCEntryIt = BlockMaskCache.find(BB);
-  assert(BCEntryIt != BlockMaskCache.end() &&
-         "Trying to access mask for block without one.");
-  return BCEntryIt->second;
-}
-
-void VPRecipeBuilder::createBlockInMask(BasicBlock *BB) {
-  assert(OrigLoop->contains(BB) && "Block is not a part of a loop");
-  assert(BlockMaskCache.count(BB) == 0 && "Mask for block already computed");
-  assert(OrigLoop->getHeader() != BB &&
-         "Loop header must have cached block mask");
-
-  // All-one mask is modelled as no-mask following the convention for masked
-  // load/store/gather/scatter. Initialize BlockMask to no-mask.
-  VPValue *BlockMask = nullptr;
-  // This is the block mask. We OR all unique incoming edges.
-  for (auto *Predecessor :
-       SetVector<BasicBlock *>(llvm::from_range, predecessors(BB))) {
-    VPValue *EdgeMask = createEdgeMask(Predecessor, BB);
-    if (!EdgeMask) { // Mask of predecessor is all-one so mask of block is too.
-      BlockMaskCache[BB] = EdgeMask;
-      return;
-    }
-
-    if (!BlockMask) { // BlockMask has its initialized nullptr value.
-      BlockMask = EdgeMask;
-      continue;
-    }
-
-    BlockMask = Builder.createOr(BlockMask, EdgeMask, {});
-  }
-
-  BlockMaskCache[BB] = BlockMask;
-}
-
 VPWidenMemoryRecipe *
 VPRecipeBuilder::tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands,
                                   VFRange &Range) {
@@ -8418,7 +8239,7 @@ VPRecipeBuilder::tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands,
 
   VPValue *Mask = nullptr;
   if (Legal->isMaskRequired(I))
-    Mask = getBlockInMask(I->getParent());
+    Mask = getBlockInMask(Builder.getInsertBlock());
 
   // Determine if the pointer operand of the access is either consecutive or
   // reverse consecutive.
@@ -8538,38 +8359,6 @@ VPWidenIntOrFpInductionRecipe *VPRecipeBuilder::tryToOptimizeInductionTruncate(
   return nullptr;
 }
 
-VPBlendRecipe *VPRecipeBuilder::tryToBlend(PHINode *Phi,
-                                           ArrayRef<VPValue *> Operands) {
-  unsigned NumIncoming = Phi->getNumIncomingValues();
-
-  // We know that all PHIs in non-header blocks are converted into selects, so
-  // we don't have to worry about the insertion order and we can just use the
-  // builder. At this point we generate the predication tree. There may be
-  // duplications since this is a simple recursive scan, but future
-  // optimizations will clean it up.
-
-  // Map incoming IR BasicBlocks to incoming VPValues, for lookup below.
-  // TODO: Add operands and masks in order from the VPlan predecessors.
-  DenseMap<BasicBlock *, VPValue *> VPIncomingValues;
-  for (const auto &[Idx, Pred] : enumerate(predecessors(Phi->getParent())))
-    VPIncomingValues[Pred] = Operands[Idx];
-
-  SmallVector<VPValue *, 2> OperandsWithMask;
-  for (unsigned In = 0; In < NumIncoming; In++) {
-    BasicBlock *Pred = Phi->getIncomingBlock(In);
-    OperandsWithMask.push_back(VPIncomingValues.lookup(Pred));
-    VPValue *EdgeMask = getEdgeMask(Pred, Phi->getParent());
-    if (!EdgeMask) {
-      assert(In == 0 && "Both null and non-null edge masks found");
-      assert(all_equal(Operands) &&
-             "Distinct incoming values with one having a full mask");
-      break;
-    }
-    OperandsWithMask.push_back(EdgeMask);
-  }
-  return new VPBlendRecipe(Phi, OperandsWithMask);
-}
-
 VPSingleDefRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
                                                    ArrayRef<VPValue *> Operands,
                                                    VFRange &Range) {
@@ -8645,7 +8434,7 @@ VPSingleDefRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
       //      all-true mask.
       VPValue *Mask = nullptr;
       if (Legal->isMaskRequired(CI))
-        Mask = getBlockInMask(CI->getParent());
+        Mask = getBlockInMask(Builder.getInsertBlock());
       else
         Mask = Plan.getOrAddLiveIn(
             ConstantInt::getTrue(IntegerType::getInt1Ty(CI->getContext())));
@@ -8687,7 +8476,7 @@ VPWidenRecipe *VPRecipeBuilder::tryToWiden(Instruction *I,
     // div/rem operation itself.  Otherwise fall through to general handling below.
     if (CM.isPredicatedInst(I)) {
       SmallVector<VPValue *> Ops(Operands);
-      VPValue *Mask = getBlockInMask(I->getParent());
+      VPValue *Mask = getBlockInMask(Builder.getInsertBlock());
       VPValue *One =
           Plan.getOrAddLiveIn(ConstantInt::get(I->getType(), 1u, false));
       auto *SafeRHS = Builder.createSelect(Mask, Ops[1], One, I->getDebugLoc());
@@ -8769,7 +8558,7 @@ VPRecipeBuilder::tryToWidenHistogram(const HistogramInfo *HI,
   // In case of predicated execution (due to tail-folding, or conditional
   // execution, or both), pass the relevant mask.
   if (Legal->isMaskRequired(HI->Store))
-    HGramOps.push_back(getBlockInMask(HI->Store->getParent()));
+    HGramOps.push_back(getBlockInMask(Builder.getInsertBlock()));
 
   return new VPHistogramRecipe(Opcode, HGramOps, HI->Store->getDebugLoc());
 }
@@ -8823,7 +8612,7 @@ VPRecipeBuilder::handleReplication(Instruction *I, ArrayRef<VPValue *> Operands,
     // added initially. Masked replicate recipes will later be placed under an
     // if-then construct to prevent side-effects. Generate recipes to compute
     // the block mask for this region.
-    BlockInMask = getBlockInMask(I->getParent());
+    BlockInMask = getBlockInMask(Builder.getInsertBlock());
   }
 
   // Note that there is some custom logic to mark some intrinsics as uniform
@@ -8960,9 +8749,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(
   // nodes, calls and memory operations.
   VPRecipeBase *Recipe;
   if (auto *Phi = dyn_cast<PHINode>(Instr)) {
-    if (Phi->getParent() != OrigLoop->getHeader())
-      return tryToBlend(Phi, Operands);
-
+    assert(Phi->getParent() == OrigLoop->getHeader() &&
+           "Non-header phis should have been handled during predication");
     assert(Operands.size() == 2 && "Must have 2 operands for header phis");
     if ((Recipe = tryToOptimizeInductionPHI(Phi, Operands, Range)))
       return Recipe;
@@ -9067,7 +8855,7 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
             ReductionOpcode == Instruction::Sub) &&
            "Expected an ADD or SUB operation for predicated partial "
            "reductions (because the neutral element in the mask is zero)!");
-    Cond = getBlockInMask(Reduction->getParent());
+    Cond = getBlockInMask(Builder.getInsertBlock());
     VPValue *Zero =
         Plan.getOrAddLiveIn(ConstantInt::get(Reduction->getType(), 0));
     BinOp = Builder.createSelect(Cond, BinOp, Zero, Reduction->getDebugLoc());
@@ -9378,8 +9166,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
             return !CM.requiresScalarEpilogue(VF.isVector());
           },
           Range);
-  DenseMap<VPBlockBase *, BasicBlock *> VPB2IRBB;
-  auto Plan = VPlanTransforms::buildPlainCFG(OrigLoop, *LI, VPB2IRBB);
+  auto Plan = VPlanTransforms::buildPlainCFG(OrigLoop, *LI);
   VPlanTransforms::prepareForVectorization(
       *Plan, Legal->getWidestInductionType(), PSE, RequiresScalarEpilogueCheck,
       CM.foldTailByMasking(), OrigLoop,
@@ -9412,9 +9199,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
     cast<VPRecipeWithIRFlags>(IVInc)->dropPoisonGeneratingFlags();
   }
 
-  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
-                                Builder, LVer);
-
   // ---------------------------------------------------------------------------
   // Pre-construction: record ingredients whose recipes we'll need to further
   // process after constructing the initial VPlan.
@@ -9442,44 +9226,29 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
   }
 
   // ---------------------------------------------------------------------------
-  // Construct recipes for the instructions in the loop
+  // Predicate and linearize the top-level loop region.
   // ---------------------------------------------------------------------------
+  DenseMap<VPBasicBlock *, VPValue *> BlockMaskCache;
+  VPlanTransforms::predicateAndLinearize(*Plan, CM.foldTailByMasking(),
+                                         BlockMaskCache);
 
-  VPRegionBlock *LoopRegion = Plan->getVectorLoopRegion();
-  VPBasicBlock *HeaderVPBB = LoopRegion->getEntryBasicBlock();
-  BasicBlock *HeaderBB = OrigLoop->getHeader();
-  bool NeedsMasks =
-      CM.foldTailByMasking() ||
-      any_of(OrigLoop->blocks(), [this, HeaderBB](BasicBlock *BB) {
-        bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
-        return Legal->blockNeedsPredication(BB) || NeedsBlends;
-      });
-
+  // ---------------------------------------------------------------------------
+  // Construct recipes for the instructions in the loop
+  // ---------------------------------------------------------------------------
+  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
+                                Builder, BlockMaskCache, LVer);
   RecipeBuilder.collectScaledReductions(Range);
 
-  auto *MiddleVPBB = Plan->getMiddleBlock();
-
   // Scan the body of the loop in a topological order to visit each basic block
   // after having visited its predecessor basic blocks.
+  VPRegionBlock *LoopRegion = Plan->getVectorLoopRegion();
+  VPBasicBlock *HeaderVPBB = LoopRegion->getEntryBasicBlock();
   ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT(
       HeaderVPBB);
 
+  auto *MiddleVPBB = Plan->getMiddleBlock();
   VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
-  VPBlockBase *PrevVPBB = nullptr;
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) {
-    // Create mask based on the IR BB corresponding to VPBB.
-    // TODO: Predicate directly based on VPlan.
-    Builder.setInsertPoint(VPBB, VPBB->begin());
-    if (VPBB == HeaderVPBB) {
-      Builder.setInsertPoint(VPBB, VPBB->getFirstNonPhi());
-      RecipeBuilder.createHeaderMask();
-    } else if (NeedsMasks) {
-      // FIXME: At the moment, masks need to be placed at the beginning of the
-      // block, as blends introduced for phi nodes need to use it. The created
-      // blends should be sunk after the mask recipes.
-      RecipeBuilder.createBlockInMask(VPB2IRBB.lookup(VPBB));
-    }
-
     // Convert input VPInstructions to widened recipes.
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
       auto *SingleDef = cast<VPSingleDefRecipe>(&R);
@@ -9489,7 +9258,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
       // latter are added above for masking.
       // FIXME: Migrate code relying on the underlying instruction from VPlan0
       // to construct recipes below to not use the underlying instruction.
-      if (isa<VPCanonicalIVPHIRecipe, VPWidenCanonicalIVRecipe>(&R) ||
+      if (isa<VPCanonicalIVPHIRecipe, VPWidenCanonicalIVRecipe, VPBlendRecipe>(
+              &R) ||
           (isa<VPInstruction>(&R) && !UnderlyingValue))
         continue;
 
@@ -9498,14 +9268,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
       assert((isa<VPWidenPHIRecipe>(&R) || isa<VPInstruction>(&R)) &&
              UnderlyingValue && "unsupported recipe");
 
-      if (isa<VPInstruction>(&R) &&
-          (cast<VPInstruction>(&R)->getOpcode() ==
-               VPInstruction::BranchOnCond ||
-           (cast<VPInstruction>(&R)->getOpcode() == Instruction::Switch))) {
-        R.eraseFromParent();
-        break;
-      }
-
       // TODO: Gradually replace uses of underlying instruction by analyses on
       // VPlan.
       Instruction *Instr = cast<Instruction>(UnderlyingValue);
@@ -9542,22 +9304,19 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
       } else {
         Builder.insert(Recipe);
       }
-      if (Recipe->getNumDefinedValues() == 1)
+      if (Recipe->getNumDefinedValues() == 1) {
         SingleDef->replaceAllUsesWith(Recipe->getVPSingleValue());
-      else
+        // replaceAllUsesWith may invalidate the block mask cache. Update it.
+        // TODO: Include the masks as operands in the predicated VPlan directly
+        // to remove the need to keep a map of masks beyond the predication
+        // transform.
+        RecipeBuilder.updateBlockMaskCache(SingleDef,
+                                           Recipe->getVPSingleValue());
+      } else
         assert(Recipe->getNumDefinedValues() == 0 &&
                "Unexpected multidef recipe");
       R.eraseFromParent();
     }
-
-    // Flatten the CFG in the loop. Masks for blocks have already been generated
-    // and added to recipes as needed. To do so, first disconnect VPBB from its
-    // successors. Then connect VPBB to the previously visited VPBB.
-    for (auto *Succ : to_vector(VPBB->getSuccessors()))
-      VPBlockUtils::disconnectBlocks(VPBB, Succ);
-    if (PrevVPBB)
-      VPBlockUtils::connectBlocks(PrevVPBB, VPBB);
-    PrevVPBB = VPBB;
   }
 
   assert(isa<VPRegionBlock>(Plan->getVectorLoopRegion()) &&
@@ -9676,8 +9435,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) {
   assert(!OrigLoop->isInnermost());
   assert(EnableVPlanNativePath && "VPlan-native path is not enabled.");
 
-  DenseMap<VPBlockBase *, BasicBlock *> VPB2IRBB;
-  auto Plan = VPlanTransforms::buildPlainCFG(OrigLoop, *LI, VPB2IRBB);
+  auto Plan = VPlanTransforms::buildPlainCFG(OrigLoop, *LI);
   VPlanTransforms::prepareForVectorization(
       *Plan, Legal->getWidestInductionType(), PSE, true, false, OrigLoop,
       getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()), false,
@@ -9697,8 +9455,9 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) {
 
   // Collect mapping of IR header phis to header phi recipes, to be used in
   // addScalarResumePhis.
+  DenseMap<VPBasicBlock *, VPValue *> BlockMaskCache;
   VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
-                                Builder, nullptr /*LVer*/);
+                                Builder, BlockMaskCache, nullptr /*LVer*/);
   for (auto &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
     if (isa<VPCanonicalIVPHIRecipe>(&R))
       continue;
@@ -9846,7 +9605,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       BasicBlock *BB = CurrentLinkI->getParent();
       VPValue *CondOp = nullptr;
       if (CM.blockNeedsPredicationForAnyReason(BB))
-        CondOp = RecipeBuilder.getBlockInMask(BB);
+        CondOp = RecipeBuilder.getBlockInMask(CurrentLink->getParent());
 
       // Non-FP RdxDescs will have all fast math flags set, so clear them.
       FastMathFlags FMFs = isa<FPMathOperator>(CurrentLinkI)
@@ -9889,7 +9648,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
     // different numbers of lanes. Partial reductions mask the input instead.
     if (!PhiR->isInLoop() && CM.foldTailByMasking() &&
         !isa<VPPartialReductionRecipe>(OrigExitingVPV->getDefiningRecipe())) {
-      VPValue *Cond = RecipeBuilder.getBlockInMask(OrigLoop->getHeader());
+      VPValue *Cond =
+          RecipeBuilder.getBlockInMask(VectorLoopRegion->getEntryBasicBlock());
       Type *PhiTy = PhiR->getOperand(0)->getLiveInIRValue()->getType();
       std::optional<FastMathFlags> FMFs =
           PhiTy->isFloatingPointTy()
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index caa18e263676b..74efe2da46f65 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -68,15 +68,7 @@ class VPRecipeBuilder {
 
   VPBuilder &Builder;
 
-  /// When we if-convert we need to create edge masks. We have to cache values
-  /// so that we don't end up with exponential recursion/IR. Note that
-  /// if-conversion currently takes place during VPlan-construction, so these
-  /// caches are only used at that stage.
-  using EdgeMaskCacheTy =
-      DenseMap<std::pair<BasicBlock *, BasicBlock *>, VPValue *>;
-  using BlockMaskCacheTy = DenseMap<BasicBlock *, VPValue *>;
-  EdgeMaskCacheTy EdgeMaskCache;
-  BlockMaskCacheTy BlockMaskCache;
+  DenseMap<VPBasicBlock *, VPValue *> &BlockMaskCache;
 
   // VPlan construction support: Hold a mapping from ingredients to
   // their recipe.
@@ -118,11 +110,6 @@ class VPRecipeBuilder {
   tryToOptimizeInductionTruncate(TruncInst *I, ArrayRef<VPValue *> Operands,
                                  VFRange &Range);
 
-  /// Handle non-loop phi nodes. Return a new VPBlendRecipe otherwise. Currently
-  /// all such phi nodes are turned into a sequence of select instructions as
-  /// the vectorizer currently performs full if-conversion.
-  VPBlendRecipe *tryToBlend(PHINode *Phi, ArrayRef<VPValue *> Operands);
-
   /// Handle call instructions. If \p CI can be widened for \p Range.Start,
   /// return a new VPWidenCallRecipe or VPWidenIntrinsicRecipe. Range.End may be
   /// decreased to ensure same decision from \p Range.Start to \p Range.End.
@@ -160,9 +147,11 @@ class VPRecipeBuilder {
                   LoopVectorizationLegality *Legal,
                   LoopVectorizationCostModel &CM,
                   PredicatedScalarEvolution &PSE, VPBuilder &Builder,
+                  DenseMap<VPBasicBlock *, VPValue *> &BlockMaskCache,
                   LoopVersioning *LVer)
       : Plan(Plan), OrigLoop(OrigLoop), TLI(TLI), TTI(TTI), Legal(Legal),
-        CM(CM), PSE(PSE), Builder(Builder), LVer(LVer) {}
+        CM(CM), PSE(PSE), Builder(Builder), BlockMaskCache(BlockMaskCache),
+        LVer(LVer) {}
 
   std::optional<unsigned> getScalingForReduction(const Instruction *ExitInst) {
     auto It = ScaledReductionMap.find(ExitInst);
@@ -193,27 +182,10 @@ class VPRecipeBuilder {
     Ingredient2Recipe[I] = R;
   }
 
-  /// Create the mask for the vector loop header block.
-  void createHeaderMask();
-
-  /// A helper function that computes the predicate of the block BB, assuming
-  /// that the header block of the loop is set to True or the loop mask when
-  /// tail folding.
-  void createBlockInMask(BasicBlock *BB);
-
   /// Returns the *entry* mask for the block \p BB.
-  VPValue *getBlockInMask(BasicBlock *BB) const;
-
-  /// Create an edge mask for every destination of cases and/or default.
-  void createSwitchEdgeMasks(SwitchInst *SI);
-
-  /// A helper function that computes the predicate of the edge between SRC
-  /// and DST.
-  VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst);
-
-  /// A helper that returns the previously computed predicate of the edge
-  /// between SRC and DST.
-  VPValue *getEdgeMask(BasicBlock *Src, BasicBlock *Dst) const;
+  VPValue *getBlockInMask(VPBasicBlock *BB) const {
+    return BlockMaskCache.lookup(BB);
+  }
 
   /// Return the recipe created for given ingredient.
   VPRecipeBase *getRecipe(Instruction *I) {
@@ -238,6 +210,13 @@ class VPRecipeBuilder {
     }
     return Plan.getOrAddLiveIn(V);
   }
+
+  void updateBlockMaskCache(VPValue *Old, VPValue *New) {
+    for (auto &[_, V] : BlockMaskCache) {
+      if (V == Old)
+        V = New;
+    }
+  }
 };
 } // end namespace llvm
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index b924b14035261..92bd49ace3638 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -66,8 +66,7 @@ class PlainCFGBuilder {
       : TheLoop(Lp), LI(LI), Plan(std::make_unique<VPlan>(Lp)) {}
 
   /// Build plain CFG for TheLoop  and connects it to Plan's entry.
-  std::unique_ptr<VPlan>
-  buildPlainCFG(DenseMap<VPBlockBase *, BasicBlock *> &VPB2IRBB);
+  std::unique_ptr<VPlan> buildPlainCFG();
 };
 } // anonymous namespace
 
@@ -242,8 +241,7 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB,
 }
 
 // Main interface to build the plain CFG.
-std::unique_ptr<VPlan> PlainCFGBuilder::buildPlainCFG(
-    DenseMap<VPBlockBase *, BasicBlock *> &VPB2IRBB) {
+std::unique_ptr<VPlan> PlainCFGBuilder::buildPlainCFG() {
   VPIRBasicBlock *Entry = cast<VPIRBasicBlock>(Plan->getEntry());
   BB2VPBB[Entry->getIRBasicBlock()] = Entry;
   for (VPIRBasicBlock *ExitVPBB : Plan->getExitBlocks())
@@ -334,18 +332,14 @@ std::unique_ptr<VPlan> PlainCFGBuilder::buildPlainCFG(
     }
   }
 
-  for (const auto &[IRBB, VPB] : BB2VPBB)
-    VPB2IRBB[VPB] = IRBB;
-
   LLVM_DEBUG(Plan->setName("Plain CFG\n"); dbgs() << *Plan);
   return std::move(Plan);
 }
 
-std::unique_ptr<VPlan> VPlanTransforms::buildPlainCFG(
-    Loop *TheLoop, LoopInfo &LI,
-    DenseMap<VPBlockBase *, BasicBlock *> &VPB2IRBB) {
+std::unique_ptr<VPlan> VPlanTransforms::buildPlainCFG(Loop *TheLoop,
+                                                      LoopInfo &LI) {
   PlainCFGBuilder Builder(TheLoop, &LI);
-  return Builder.buildPlainCFG(VPB2IRBB);
+  return Builder.buildPlainCFG();
 }
 
 /// Checks if \p HeaderVPB is a loop header block in the plain CFG; that is, it
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp b/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
new file mode 100644
index 0000000000000..dda1f10b20c0a
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
@@ -0,0 +1,310 @@
+//===-- VPlanPredicator.cpp - VPlan predicator ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements predication for VPlans.
+///
+//===----------------------------------------------------------------------===//
+
+#include "VPRecipeBuilder.h"
+#include "VPlan.h"
+#include "VPlanCFG.h"
+#include "VPlanTransforms.h"
+#include "VPlanUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
+
+using namespace llvm;
+
+namespace {
+struct VPPredicator {
+  using BlockMaskCacheTy = DenseMap<VPBasicBlock *, VPValue *>;
+  VPPredicator(BlockMaskCacheTy &BlockMaskCache)
+      : BlockMaskCache(BlockMaskCache) {}
+
+  /// Builder to construct recipes to compute masks.
+  VPBuilder Builder;
+
+  /// When we if-convert we need to create edge masks. We have to cache values
+  /// so that we don't end up with exponential recursion/IR.
+  using EdgeMaskCacheTy =
+      DenseMap<std::pair<const VPBasicBlock *, const VPBasicBlock *>,
+               VPValue *>;
+  EdgeMaskCacheTy EdgeMaskCache;
+
+  BlockMaskCacheTy &BlockMaskCache;
+
+  /// Returns the previously computed predicate of the edge between \p Src and
+  /// \p Dst.
+  VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const {
+    return EdgeMaskCache.lookup({Src, Dst});
+  }
+
+  /// Returns the *entry* mask for \p VPBB.
+  VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
+    return BlockMaskCache.lookup(VPBB);
+  }
+  void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) {
+    // TODO: Include the masks as operands in the predicated VPlan directly to
+    // remove the need to keep a map of masks beyond the predication transform.
+    assert(!BlockMaskCache.contains(VPBB) && "Mask already set");
+    BlockMaskCache[VPBB] = Mask;
+  }
+
+  /// Compute and cache the mask for the vector loop header block.
+  void createHeaderMask(VPBasicBlock *HeaderVPBB, bool FoldTail);
+
+  /// Compute and return the predicate of \p VPBB, assuming that the header
+  /// block of the loop is set to True or the loop mask when tail folding.
+  VPValue *createBlockInMask(VPBasicBlock *VPBB);
+
+  /// Compute and return the predicate of the edge between \p Src and \p Dst.
+  VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst);
+
+  /// Create an edge mask for every destination of cases and/or default.
+  void createSwitchEdgeMasks(VPInstruction *SI);
+};
+} // namespace
+
+VPValue *VPPredicator::createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst) {
+  assert(is_contained(Dst->getPredecessors(), Src) && "Invalid edge");
+
+  // Look for cached value.
+  VPValue *EdgeMask = getEdgeMask(Src, Dst);
+  if (EdgeMask)
+    return EdgeMask;
+
+  VPValue *SrcMask = getBlockInMask(Src);
+
+  // An empty block or a block with one successor forwards its block-in mask.
+  if (Src->empty() || Src->getNumSuccessors() == 1) {
+    EdgeMaskCache[{Src, Dst}] = SrcMask;
+    return SrcMask;
+  }
+
+  auto *Term = cast<VPInstruction>(Src->getTerminator());
+  if (Term->getOpcode() == Instruction::Switch) {
+    createSwitchEdgeMasks(Term);
+    return getEdgeMask(Src, Dst);
+  }
+
+  auto *BI = cast<VPInstruction>(Src->getTerminator());
+  assert(BI->getOpcode() == VPInstruction::BranchOnCond);
+  if (Src->getSuccessors()[0] == Src->getSuccessors()[1]) {
+    EdgeMaskCache[{Src, Dst}] = SrcMask;
+    return SrcMask;
+  }
+
+  EdgeMask = BI->getOperand(0);
+  assert(EdgeMask && "No Edge Mask found for condition");
+
+  if (Src->getSuccessors()[0] != Dst)
+    EdgeMask = Builder.createNot(EdgeMask, BI->getDebugLoc());
+
+  if (SrcMask) { // Otherwise block in-mask is all-one, no need to AND.
+    // The bitwise 'And' of SrcMask and EdgeMask introduces new UB if SrcMask
+    // is false and EdgeMask is poison. Avoid that by using 'LogicalAnd'
+    // instead which generates 'select i1 SrcMask, i1 EdgeMask, i1 false'.
+    EdgeMask = Builder.createLogicalAnd(SrcMask, EdgeMask, BI->getDebugLoc());
+  }
+
+  EdgeMaskCache[{Src, Dst}] = EdgeMask;
+  return EdgeMask;
+}
+
+VPValue *VPPredicator::createBlockInMask(VPBasicBlock *VPBB) {
+  Builder.setInsertPoint(VPBB, VPBB->begin());
+  // All-one mask is modelled as no-mask following the convention for masked
+  // load/store/gather/scatter. Initialize BlockMask to no-mask.
+  VPValue *BlockMask = nullptr;
+  // This is the block mask. We OR all unique incoming edges.
+  for (auto *Predecessor : SetVector<VPBlockBase *>(
+           VPBB->getPredecessors().begin(), VPBB->getPredecessors().end())) {
+    VPValue *EdgeMask = createEdgeMask(cast<VPBasicBlock>(Predecessor), VPBB);
+    if (!EdgeMask) { // Mask of predecessor is all-one so mask of block is
+                     // too.
+      setBlockInMask(VPBB, EdgeMask);
+      return EdgeMask;
+    }
+
+    if (!BlockMask) { // BlockMask has its initialized nullptr value.
+      BlockMask = EdgeMask;
+      continue;
+    }
+
+    BlockMask = Builder.createOr(BlockMask, EdgeMask, {});
+  }
+
+  setBlockInMask(VPBB, BlockMask);
+  return BlockMask;
+}
+
+void VPPredicator::createHeaderMask(VPBasicBlock *HeaderVPBB, bool FoldTail) {
+  if (!FoldTail) {
+    setBlockInMask(HeaderVPBB, nullptr);
+    return;
+  }
+
+  // Introduce the early-exit compare IV <= BTC to form header block mask.
+  // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by
+  // constructing the desired canonical IV in the header block as its first
+  // non-phi instructions.
+
+  auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi();
+  auto &Plan = *HeaderVPBB->getPlan();
+  auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV());
+  HeaderVPBB->insert(IV, NewInsertionPoint);
+
+  VPBuilder::InsertPointGuard Guard(Builder);
+  Builder.setInsertPoint(HeaderVPBB, NewInsertionPoint);
+  VPValue *BlockMask = nullptr;
+  VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
+  BlockMask = Builder.createICmp(CmpInst::ICMP_ULE, IV, BTC);
+  setBlockInMask(HeaderVPBB, BlockMask);
+}
+
+void VPPredicator::createSwitchEdgeMasks(VPInstruction *SI) {
+  VPBasicBlock *Src = SI->getParent();
+
+  // Create masks where the terminator in Src is a switch. We create mask for
+  // all edges at the same time. This is more efficient, as we can create and
+  // collect compares for all cases once.
+  VPValue *Cond = SI->getOperand(0);
+  VPBasicBlock *DefaultDst = cast<VPBasicBlock>(Src->getSuccessors()[0]);
+  MapVector<VPBasicBlock *, SmallVector<VPValue *>> Dst2Compares;
+  for (const auto &[Idx, Succ] :
+       enumerate(ArrayRef(Src->getSuccessors()).drop_front())) {
+    VPBasicBlock *Dst = cast<VPBasicBlock>(Succ);
+    assert(!EdgeMaskCache.contains({Src, Dst}) && "Edge masks already created");
+    //  Cases whose destination is the same as default are redundant and can
+    //  be ignored - they will get there anyhow.
+    if (Dst == DefaultDst)
+      continue;
+    auto &Compares = Dst2Compares[Dst];
+    VPValue *V = SI->getOperand(Idx + 1);
+    Compares.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
+  }
+
+  // We need to handle 2 separate cases below for all entries in Dst2Compares,
+  // which excludes destinations matching the default destination.
+  VPValue *SrcMask = getBlockInMask(Src);
+  VPValue *DefaultMask = nullptr;
+  for (const auto &[Dst, Conds] : Dst2Compares) {
+    // 1. Dst is not the default destination. Dst is reached if any of the
+    // cases with destination == Dst are taken. Join the conditions for each
+    // case whose destination == Dst using an OR.
+    VPValue *Mask = Conds[0];
+    for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
+      Mask = Builder.createOr(Mask, V);
+    if (SrcMask)
+      Mask = Builder.createLogicalAnd(SrcMask, Mask);
+    EdgeMaskCache[{Src, Dst}] = Mask;
+
+    // 2. Create the mask for the default destination, which is reached if
+    // none of the cases with destination != default destination are taken.
+    // Join the masks of all non-default destinations using an OR here; the
+    // result is negated after the loop.
+    DefaultMask = DefaultMask ? Builder.createOr(DefaultMask, Mask) : Mask;
+  }
+
+  if (DefaultMask) {
+    DefaultMask = Builder.createNot(DefaultMask);
+    if (SrcMask)
+      DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
+  }
+  EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
+}
+
+void VPlanTransforms::predicateAndLinearize(
+    VPlan &Plan, bool FoldTail,
+    DenseMap<VPBasicBlock *, VPValue *> &BlockMaskCache) {
+  VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
+  // Scan the body of the loop in a topological order to visit each basic block
+  // after having visited its predecessor basic blocks.
+  VPBasicBlock *Header = LoopRegion->getEntryBasicBlock();
+  ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT(
+      Header);
+  VPPredicator Predicator(BlockMaskCache);
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) {
+    if (VPBB == Header) {
+      Predicator.createHeaderMask(Header, FoldTail);
+      continue;
+    }
+
+    SmallVector<VPWidenPHIRecipe *> Phis;
+    for (VPRecipeBase &R : VPBB->phis())
+      Phis.push_back(cast<VPWidenPHIRecipe>(&R));
+
+    Predicator.createBlockInMask(VPBB);
+
+    for (VPWidenPHIRecipe *Phi : Phis) {
+      PHINode *IRPhi = cast<PHINode>(Phi->getUnderlyingValue());
+
+      unsigned NumIncoming = IRPhi->getNumIncomingValues();
+
+      // We know that all PHIs in non-header blocks are converted into selects,
+      // so we don't have to worry about the insertion order and we can just use
+      // the builder. At this point we generate the predication tree. There may
+      // be duplications since this is a simple recursive scan, but future
+      // optimizations will clean it up.
+
+      // Map incoming IR BasicBlocks to incoming VPValues, for lookup below.
+      // TODO: Add operands and masks in order from the VPlan predecessors.
+      DenseMap<BasicBlock *, VPValue *> VPIncomingValues;
+      DenseMap<BasicBlock *, VPBasicBlock *> VPIncomingBlocks;
+      for (const auto &[Idx, Pred] :
+           enumerate(predecessors(IRPhi->getParent()))) {
+        VPIncomingValues[Pred] = Phi->getOperand(Idx);
+        VPIncomingBlocks[Pred] =
+            cast<VPBasicBlock>(VPBB->getPredecessors()[Idx]);
+      }
+
+      SmallVector<VPValue *, 2> OperandsWithMask;
+      for (unsigned In = 0; In < NumIncoming; In++) {
+        BasicBlock *Pred = IRPhi->getIncomingBlock(In);
+        OperandsWithMask.push_back(VPIncomingValues.lookup(Pred));
+        VPValue *EdgeMask =
+            Predicator.getEdgeMask(VPIncomingBlocks.lookup(Pred), VPBB);
+        if (!EdgeMask) {
+          assert(In == 0 && "Both null and non-null edge masks found");
+          assert(all_equal(Phi->operands()) &&
+                 "Distinct incoming values with one having a full mask");
+          break;
+        }
+        OperandsWithMask.push_back(EdgeMask);
+      }
+      auto *Blend = new VPBlendRecipe(IRPhi, OperandsWithMask);
+      Blend->insertBefore(Phi);
+      Phi->replaceAllUsesWith(Blend);
+      Phi->eraseFromParent();
+    }
+  }
+
+  VPBlockBase *PrevVPBB = nullptr;
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) {
+    // Handle VPBBs down to the latch.
+    if (PrevVPBB && VPBB == LoopRegion->getExiting()) {
+      VPBlockUtils::connectBlocks(PrevVPBB, VPBB);
+      break;
+    }
+
+    auto Successors = to_vector(VPBB->getSuccessors());
+    if (Successors.size() > 1)
+      VPBB->getTerminator()->eraseFromParent();
+
+    // Flatten the CFG in the loop. Masks for blocks have already been
+    // generated and added to recipes as needed. To do so, first disconnect
+    // VPBB from its successors. Then connect VPBB to the previously visited
+    // VPBB.
+    for (auto *Succ : Successors)
+      VPBlockUtils::disconnectBlocks(VPBB, Succ);
+    if (PrevVPBB)
+      VPBlockUtils::connectBlocks(PrevVPBB, VPBB);
+
+    PrevVPBB = VPBB;
+  }
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 530e06d983e23..4cc0132574e1c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -53,9 +53,7 @@ struct VPlanTransforms {
       verifyVPlanIsValid(Plan);
   }
 
-  static std::unique_ptr<VPlan>
-  buildPlainCFG(Loop *TheLoop, LoopInfo &LI,
-                DenseMap<VPBlockBase *, BasicBlock *> &VPB2IRBB);
+  static std::unique_ptr<VPlan> buildPlainCFG(Loop *TheLoop, LoopInfo &LI);
 
   /// Prepare the plan for vectorization. It will introduce a dedicated
   /// VPBasicBlock for the vector pre-header as well as a VPBasicBlock as exit
@@ -217,6 +215,16 @@ struct VPlanTransforms {
   /// candidates.
   static void narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
                                      unsigned VectorRegWidth);
+
+  /// Predicate and linearize the control-flow in the top-level loop region of
+  /// \p Plan. If \p FoldTail is true, also create a mask guarding the loop
+  /// header, otherwise use all-true for the header mask. Masks for blocks are
+  /// added to \p BlockMaskCache, which in turn is temporarily used for wide
+  /// recipe construction. This argument is temporary and will be removed in the
+  /// future.
+  static void
+  predicateAndLinearize(VPlan &Plan, bool FoldTail,
+                        DenseMap<VPBasicBlock *, VPValue *> &BlockMaskCache);
 };
 
 } // namespace llvm
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
index 15e21972840f6..e2ad65b93e3dd 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
@@ -71,8 +71,7 @@ class VPlanTestIRBase : public testing::Test {
 
     Loop *L = LI->getLoopFor(LoopHeader);
     PredicatedScalarEvolution PSE(*SE, *L);
-    DenseMap<VPBlockBase *, BasicBlock *> VPB2IRBB;
-    auto Plan = VPlanTransforms::buildPlainCFG(L, *LI, VPB2IRBB);
+    auto Plan = VPlanTransforms::buildPlainCFG(L, *LI);
     VFRange R(ElementCount::getFixed(1), ElementCount::getFixed(2));
     VPlanTransforms::prepareForVectorization(*Plan, IntegerType::get(*Ctx, 64),
                                              PSE, true, false, L, {}, false, R);



More information about the llvm-commits mailing list