[llvm] [VPlan] Manage created blocks directly in VPlan. (NFC) (PR #120918)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 30 02:47:41 PST 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/120918

>From dd45cad20284589bbe26db0c64bf8d1ad3210e91 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 18 Dec 2024 14:33:49 +0000
Subject: [PATCH 1/6] [VPlan] Manage created blocks directly in VPlan. (NFC)

This patch changes the way blocks are managed by VPlan. Previously, all
blocks reachable from the entry block were deleted when a VPlan was
destroyed. With this patch, each VPlan keeps track of the blocks created
for it in a list, and all blocks in that list are deleted when the VPlan
is destroyed. To do so, block creation is funneled through helpers
directly in VPlan.

The main advantage of doing so is that it simplifies CFG transformations,
as those no longer have to take care of deleting any blocks; they only
need to adjust the CFG. This helps to simplify
https://github.com/llvm/llvm-project/pull/108378 and
https://github.com/llvm/llvm-project/pull/106748.

This also simplifies handling of 'immutable' blocks that a VPlan holds
references to, which at the moment consist only of the scalar header
block.

Note that the original constructors taking VPBlockBase are retained for
now, as they are still used by unit tests.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  7 +-
 llvm/lib/Transforms/Vectorize/VPlan.cpp       | 84 ++++++++++++-------
 llvm/lib/Transforms/Vectorize/VPlan.h         | 54 ++++++------
 .../Transforms/Vectorize/VPlanHCFGBuilder.cpp |  4 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 19 +++--
 5 files changed, 97 insertions(+), 71 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 355ff40ce770e7..4f2bf097d0cde5 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2477,7 +2477,7 @@ static void introduceCheckBlockInVPlan(VPlan &Plan, BasicBlock *CheckIRBB) {
     assert(PreVectorPH->getNumSuccessors() == 2 && "Expected 2 successors");
     assert(PreVectorPH->getSuccessors()[0] == ScalarPH &&
            "Unexpected successor");
-    VPIRBasicBlock *CheckVPIRBB = VPIRBasicBlock::fromBasicBlock(CheckIRBB);
+    VPIRBasicBlock *CheckVPIRBB = Plan.createVPIRBasicBlock(CheckIRBB);
     VPBlockUtils::insertOnEdge(PreVectorPH, VectorPH, CheckVPIRBB);
     PreVectorPH = CheckVPIRBB;
   }
@@ -8189,11 +8189,10 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck(
 
   // A new entry block has been created for the epilogue VPlan. Hook it in, as
   // otherwise we would try to modify the entry to the main vector loop.
-  VPIRBasicBlock *NewEntry = VPIRBasicBlock::fromBasicBlock(Insert);
+  VPIRBasicBlock *NewEntry = Plan.createVPIRBasicBlock(Insert);
   VPBasicBlock *OldEntry = Plan.getEntry();
   VPBlockUtils::reassociateBlocks(OldEntry, NewEntry);
   Plan.setEntry(NewEntry);
-  delete OldEntry;
 
   introduceCheckBlockInVPlan(Plan, Insert);
   return Insert;
@@ -9463,7 +9462,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
-    VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
+    VPBlockUtils::insertBlockAfter(Plan->createVPBasicBlock(""), VPBB);
     VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 9a082921d4f7f2..e03847cea131f0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -205,11 +205,6 @@ VPBlockBase *VPBlockBase::getEnclosingBlockWithPredecessors() {
   return Parent->getEnclosingBlockWithPredecessors();
 }
 
-void VPBlockBase::deleteCFG(VPBlockBase *Entry) {
-  for (VPBlockBase *Block : to_vector(vp_depth_first_shallow(Entry)))
-    delete Block;
-}
-
 VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
   iterator It = begin();
   while (It != end() && It->isPhi())
@@ -474,6 +469,16 @@ void VPIRBasicBlock::execute(VPTransformState *State) {
   connectToPredecessors(State->CFG);
 }
 
+VPIRBasicBlock *VPIRBasicBlock::clone() {
+  auto *NewBlock = getPlan()->createVPIRBasicBlock(IRBB);
+  for (VPRecipeBase &R : make_early_inc_range(*NewBlock))
+    R.eraseFromParent();
+
+  for (VPRecipeBase &R : Recipes)
+    NewBlock->appendRecipe(R.clone());
+  return NewBlock;
+}
+
 void VPBasicBlock::execute(VPTransformState *State) {
   bool Replica = bool(State->Lane);
   BasicBlock *NewBB = State->CFG.PrevBB; // Reuse it if possible.
@@ -523,6 +528,13 @@ void VPBasicBlock::dropAllReferences(VPValue *NewValue) {
   }
 }
 
+VPBasicBlock *VPBasicBlock::clone() {
+  auto *NewBlock = getPlan()->createVPBasicBlock(getName());
+  for (VPRecipeBase &R : *this)
+    NewBlock->appendRecipe(R.clone());
+  return NewBlock;
+}
+
 void VPBasicBlock::executeRecipes(VPTransformState *State, BasicBlock *BB) {
   LLVM_DEBUG(dbgs() << "LV: vectorizing VPBB:" << getName()
                     << " in BB:" << BB->getName() << '\n');
@@ -541,7 +553,7 @@ VPBasicBlock *VPBasicBlock::splitAt(iterator SplitAt) {
 
   SmallVector<VPBlockBase *, 2> Succs(successors());
   // Create new empty block after the block to split.
-  auto *SplitBlock = new VPBasicBlock(getName() + ".split");
+  auto *SplitBlock = getPlan()->createVPBasicBlock(getName() + ".split");
   VPBlockUtils::insertBlockAfter(SplitBlock, this);
 
   // Finally, move the recipes starting at SplitAt to new block.
@@ -701,8 +713,8 @@ static std::pair<VPBlockBase *, VPBlockBase *> cloneFrom(VPBlockBase *Entry) {
 
 VPRegionBlock *VPRegionBlock::clone() {
   const auto &[NewEntry, NewExiting] = cloneFrom(getEntry());
-  auto *NewRegion =
-      new VPRegionBlock(NewEntry, NewExiting, getName(), isReplicator());
+  auto *NewRegion = getPlan()->createVPRegionBlock(NewEntry, NewExiting,
+                                                   getName(), isReplicator());
   for (VPBlockBase *Block : vp_depth_first_shallow(NewEntry))
     Block->setParent(NewRegion);
   return NewRegion;
@@ -822,17 +834,20 @@ void VPRegionBlock::print(raw_ostream &O, const Twine &Indent,
 #endif
 
 VPlan::VPlan(Loop *L) {
-  setEntry(VPIRBasicBlock::fromBasicBlock(L->getLoopPreheader()));
-  ScalarHeader = VPIRBasicBlock::fromBasicBlock(L->getHeader());
+  setEntry(createVPIRBasicBlock(L->getLoopPreheader()));
+  ScalarHeader = createVPIRBasicBlock(L->getHeader());
 }
 
 VPlan::~VPlan() {
   if (Entry) {
     VPValue DummyValue;
-    for (VPBlockBase *Block : vp_depth_first_shallow(Entry))
-      Block->dropAllReferences(&DummyValue);
 
-    VPBlockBase::deleteCFG(Entry);
+    for (auto *VPB : reverse(CreatedBlocks))
+      VPB->dropAllReferences(&DummyValue);
+
+    for (auto *VPB : reverse(CreatedBlocks)) {
+      delete VPB;
+    }
   }
   for (VPValue *VPV : VPLiveInsToFree)
     delete VPV;
@@ -840,14 +855,6 @@ VPlan::~VPlan() {
     delete BackedgeTakenCount;
 }
 
-VPIRBasicBlock *VPIRBasicBlock::fromBasicBlock(BasicBlock *IRBB) {
-  auto *VPIRBB = new VPIRBasicBlock(IRBB);
-  for (Instruction &I :
-       make_range(IRBB->begin(), IRBB->getTerminator()->getIterator()))
-    VPIRBB->appendRecipe(new VPIRInstruction(I));
-  return VPIRBB;
-}
-
 VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
                                    PredicatedScalarEvolution &PSE,
                                    bool RequiresScalarEpilogueCheck,
@@ -861,7 +868,7 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
   // an epilogue vector loop, the original entry block here will be replaced by
   // a new VPIRBasicBlock wrapping the entry to the epilogue vector loop after
   // generating code for the main vector loop.
-  VPBasicBlock *VecPreheader = new VPBasicBlock("vector.ph");
+  VPBasicBlock *VecPreheader = Plan->createVPBasicBlock("vector.ph");
   VPBlockUtils::connectBlocks(Plan->getEntry(), VecPreheader);
 
   // Create SCEV and VPValue for the trip count.
@@ -878,17 +885,17 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
 
   // Create VPRegionBlock, with empty header and latch blocks, to be filled
   // during processing later.
-  VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body");
-  VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch");
+  VPBasicBlock *HeaderVPBB = Plan->createVPBasicBlock("vector.body");
+  VPBasicBlock *LatchVPBB = Plan->createVPBasicBlock("vector.latch");
   VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB);
-  auto *TopRegion = new VPRegionBlock(HeaderVPBB, LatchVPBB, "vector loop",
-                                      false /*isReplicator*/);
+  auto *TopRegion = Plan->createVPRegionBlock(
+      HeaderVPBB, LatchVPBB, "vector loop", false /*isReplicator*/);
 
   VPBlockUtils::insertBlockAfter(TopRegion, VecPreheader);
-  VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block");
+  VPBasicBlock *MiddleVPBB = Plan->createVPBasicBlock("middle.block");
   VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion);
 
-  VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
+  VPBasicBlock *ScalarPH = Plan->createVPBasicBlock("scalar.ph");
   VPBlockUtils::connectBlocks(ScalarPH, ScalarHeader);
   if (!RequiresScalarEpilogueCheck) {
     VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
@@ -904,7 +911,7 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
   //    we unconditionally branch to the scalar preheader.  Do nothing.
   // 3) Otherwise, construct a runtime check.
   BasicBlock *IRExitBlock = TheLoop->getUniqueLatchExitBlock();
-  auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock);
+  auto *VPExitBlock = Plan->createVPIRBasicBlock(IRExitBlock);
   // The connection order corresponds to the operands of the conditional branch.
   VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
   VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
@@ -960,15 +967,13 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV,
 /// have a single predecessor, which is rewired to the new VPIRBasicBlock. All
 /// successors of VPBB, if any, are rewired to the new VPIRBasicBlock.
 static void replaceVPBBWithIRVPBB(VPBasicBlock *VPBB, BasicBlock *IRBB) {
-  VPIRBasicBlock *IRVPBB = VPIRBasicBlock::fromBasicBlock(IRBB);
+  VPIRBasicBlock *IRVPBB = VPBB->getPlan()->createVPIRBasicBlock(IRBB);
   for (auto &R : make_early_inc_range(*VPBB)) {
     assert(!R.isPhi() && "Tried to move phi recipe to end of block");
     R.moveBefore(*IRVPBB, IRVPBB->end());
   }
 
   VPBlockUtils::reassociateBlocks(VPBB, IRVPBB);
-
-  delete VPBB;
 }
 
 /// Generate the code inside the preheader and body of the vectorized loop.
@@ -1217,6 +1222,7 @@ static void remapOperands(VPBlockBase *Entry, VPBlockBase *NewEntry,
 }
 
 VPlan *VPlan::duplicate() {
+  unsigned CreatedBlockSize = CreatedBlocks.size();
   // Clone blocks.
   const auto &[NewEntry, __] = cloneFrom(Entry);
 
@@ -1257,9 +1263,23 @@ VPlan *VPlan::duplicate() {
   assert(Old2NewVPValues.contains(TripCount) &&
          "TripCount must have been added to Old2NewVPValues");
   NewPlan->TripCount = Old2NewVPValues[TripCount];
+
+  for (unsigned I = CreatedBlockSize; I != CreatedBlocks.size(); ++I)
+    NewPlan->CreatedBlocks.push_back(CreatedBlocks[I]);
+  CreatedBlocks.truncate(CreatedBlockSize);
+
   return NewPlan;
 }
 
+VPIRBasicBlock *VPlan::createVPIRBasicBlock(BasicBlock *IRBB) {
+  auto *VPIRBB = new VPIRBasicBlock(IRBB);
+  for (Instruction &I :
+       make_range(IRBB->begin(), IRBB->getTerminator()->getIterator()))
+    VPIRBB->appendRecipe(new VPIRInstruction(I));
+  CreatedBlocks.push_back(VPIRBB);
+  return VPIRBB;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 
 Twine VPlanPrinter::getUID(const VPBlockBase *Block) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e2c0ff79546758..eb0e3baa8d4f38 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -636,9 +636,6 @@ class VPBlockBase {
   /// Return the cost of the block.
   virtual InstructionCost cost(ElementCount VF, VPCostContext &Ctx) = 0;
 
-  /// Delete all blocks reachable from a given VPBlockBase, inclusive.
-  static void deleteCFG(VPBlockBase *Entry);
-
   /// Return true if it is legal to hoist instructions into this block.
   bool isLegalToHoistInto() {
     // There are currently no constraints that prevent an instruction to be
@@ -3638,12 +3635,7 @@ class VPBasicBlock : public VPBlockBase {
 
   /// Clone the current block and it's recipes, without updating the operands of
   /// the cloned recipes.
-  VPBasicBlock *clone() override {
-    auto *NewBlock = new VPBasicBlock(getName());
-    for (VPRecipeBase &R : *this)
-      NewBlock->appendRecipe(R.clone());
-    return NewBlock;
-  }
+  VPBasicBlock *clone() override;
 
 protected:
   /// Execute the recipes in the IR basic block \p BB.
@@ -3679,20 +3671,11 @@ class VPIRBasicBlock : public VPBasicBlock {
     return V->getVPBlockID() == VPBlockBase::VPIRBasicBlockSC;
   }
 
-  /// Create a VPIRBasicBlock from \p IRBB containing VPIRInstructions for all
-  /// instructions in \p IRBB, except its terminator which is managed in VPlan.
-  static VPIRBasicBlock *fromBasicBlock(BasicBlock *IRBB);
-
   /// The method which generates the output IR instructions that correspond to
   /// this VPBasicBlock, thereby "executing" the VPlan.
   void execute(VPTransformState *State) override;
 
-  VPIRBasicBlock *clone() override {
-    auto *NewBlock = new VPIRBasicBlock(IRBB);
-    for (VPRecipeBase &R : Recipes)
-      NewBlock->appendRecipe(R.clone());
-    return NewBlock;
-  }
+  VPIRBasicBlock *clone() override;
 
   BasicBlock *getIRBasicBlock() const { return IRBB; }
 };
@@ -3732,11 +3715,6 @@ class VPRegionBlock : public VPBlockBase {
         IsReplicator(IsReplicator) {}
 
   ~VPRegionBlock() override {
-    if (Entry) {
-      VPValue DummyValue;
-      Entry->dropAllReferences(&DummyValue);
-      deleteCFG(Entry);
-    }
   }
 
   /// Method to support type inquiry through isa, cast, and dyn_cast.
@@ -3863,6 +3841,8 @@ class VPlan {
   /// been modeled in VPlan directly.
   DenseMap<const SCEV *, VPValue *> SCEVToExpansion;
 
+  SmallVector<VPBlockBase *> CreatedBlocks;
+
 public:
   /// Construct a VPlan with \p Entry to the plan and with \p ScalarHeader
   /// wrapping the original header of the scalar loop.
@@ -4079,6 +4059,32 @@ class VPlan {
   /// Clone the current VPlan, update all VPValues of the new VPlan and cloned
   /// recipes to refer to the clones, and return it.
   VPlan *duplicate();
+
+  VPBasicBlock *createVPBasicBlock(const Twine &Name,
+                                   VPRecipeBase *Recipe = nullptr) {
+    auto *VPB = new VPBasicBlock(Name, Recipe);
+    CreatedBlocks.push_back(VPB);
+    return VPB;
+  }
+
+  VPRegionBlock *createVPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exiting,
+                                     const std::string &Name = "",
+                                     bool IsReplicator = false) {
+    auto *VPB = new VPRegionBlock(Entry, Exiting, Name, IsReplicator);
+    CreatedBlocks.push_back(VPB);
+    return VPB;
+  }
+
+  VPRegionBlock *createVPRegionBlock(const std::string &Name = "",
+                                     bool IsReplicator = false) {
+    auto *VPB = new VPRegionBlock(Name, IsReplicator);
+    CreatedBlocks.push_back(VPB);
+    return VPB;
+  }
+
+  /// Create a VPIRBasicBlock from \p IRBB containing VPIRInstructions for all
+  /// instructions in \p IRBB, except its terminator which is managed in VPlan.
+  VPIRBasicBlock *createVPIRBasicBlock(BasicBlock *IRBB);
 };
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
index 6e633739fcc3dd..02f4c8d8872d82 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
@@ -182,7 +182,7 @@ VPBasicBlock *PlainCFGBuilder::getOrCreateVPBB(BasicBlock *BB) {
   // Create new VPBB.
   StringRef Name = isHeaderBB(BB, TheLoop) ? "vector.body" : BB->getName();
   LLVM_DEBUG(dbgs() << "Creating VPBasicBlock for " << Name << "\n");
-  VPBasicBlock *VPBB = new VPBasicBlock(Name);
+  VPBasicBlock *VPBB = Plan.createVPBasicBlock(Name);
   BB2VPBB[BB] = VPBB;
 
   // Get or create a region for the loop containing BB.
@@ -204,7 +204,7 @@ VPBasicBlock *PlainCFGBuilder::getOrCreateVPBB(BasicBlock *BB) {
   if (LoopOfBB == TheLoop) {
     RegionOfVPBB = Plan.getVectorLoopRegion();
   } else {
-    RegionOfVPBB = new VPRegionBlock(Name.str(), false /*isReplicator*/);
+    RegionOfVPBB = Plan.createVPRegionBlock(Name.str(), false /*isReplicator*/);
     RegionOfVPBB->setParent(Loop2Region[LoopOfBB->getParentLoop()]);
   }
   RegionOfVPBB->setEntry(VPBB);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 0b809c2b34df9e..cd3ea561e6aac0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -297,8 +297,6 @@ static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) {
     DeletedRegions.insert(Region1);
   }
 
-  for (VPRegionBlock *ToDelete : DeletedRegions)
-    delete ToDelete;
   return !DeletedRegions.empty();
 }
 
@@ -310,7 +308,8 @@ static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe,
   assert(Instr->getParent() && "Predicated instruction not in any basic block");
   auto *BlockInMask = PredRecipe->getMask();
   auto *BOMRecipe = new VPBranchOnMaskRecipe(BlockInMask);
-  auto *Entry = new VPBasicBlock(Twine(RegionName) + ".entry", BOMRecipe);
+  auto *Entry =
+      Plan.createVPBasicBlock(Twine(RegionName) + ".entry", BOMRecipe);
 
   // Replace predicated replicate recipe with a replicate recipe without a
   // mask but in the replicate region.
@@ -318,7 +317,8 @@ static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe,
       PredRecipe->getUnderlyingInstr(),
       make_range(PredRecipe->op_begin(), std::prev(PredRecipe->op_end())),
       PredRecipe->isUniform());
-  auto *Pred = new VPBasicBlock(Twine(RegionName) + ".if", RecipeWithoutMask);
+  auto *Pred =
+      Plan.createVPBasicBlock(Twine(RegionName) + ".if", RecipeWithoutMask);
 
   VPPredInstPHIRecipe *PHIRecipe = nullptr;
   if (PredRecipe->getNumUsers() != 0) {
@@ -328,8 +328,10 @@ static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe,
     PHIRecipe->setOperand(0, RecipeWithoutMask);
   }
   PredRecipe->eraseFromParent();
-  auto *Exiting = new VPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe);
-  VPRegionBlock *Region = new VPRegionBlock(Entry, Exiting, RegionName, true);
+  auto *Exiting =
+      Plan.createVPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe);
+  VPRegionBlock *Region =
+      Plan.createVPRegionBlock(Entry, Exiting, RegionName, true);
 
   // Note: first set Entry as region entry and then connect successors starting
   // from it in order, to propagate the "parent" of each VPBasicBlock.
@@ -396,7 +398,6 @@ static bool mergeBlocksIntoPredecessors(VPlan &Plan) {
       VPBlockUtils::disconnectBlocks(VPBB, Succ);
       VPBlockUtils::connectBlocks(PredVPBB, Succ);
     }
-    delete VPBB;
   }
   return !WorkList.empty();
 }
@@ -1898,7 +1899,7 @@ void VPlanTransforms::handleUncountableEarlyExit(
   if (OrigLoop->getUniqueExitBlock()) {
     VPEarlyExitBlock = cast<VPIRBasicBlock>(MiddleVPBB->getSuccessors()[0]);
   } else {
-    VPEarlyExitBlock = VPIRBasicBlock::fromBasicBlock(
+    VPEarlyExitBlock = Plan.createVPIRBasicBlock(
         !OrigLoop->contains(TrueSucc) ? TrueSucc : FalseSucc);
   }
 
@@ -1908,7 +1909,7 @@ void VPlanTransforms::handleUncountableEarlyExit(
   IsEarlyExitTaken =
       Builder.createNaryOp(VPInstruction::AnyOf, {EarlyExitTakenCond});
 
-  VPBasicBlock *NewMiddle = new VPBasicBlock("middle.split");
+  VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split");
   VPBlockUtils::insertOnEdge(LoopRegion, MiddleVPBB, NewMiddle);
   VPBlockUtils::connectBlocks(NewMiddle, VPEarlyExitBlock);
   NewMiddle->swapSuccessors();

>From e72a71fa3c31872b72a84cd78ac8b9b1aa719883 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 24 Dec 2024 21:51:48 +0000
Subject: [PATCH 2/6] !fixup address comments, add comments

---
 llvm/lib/Transforms/Vectorize/VPlan.cpp |  3 ++-
 llvm/lib/Transforms/Vectorize/VPlan.h   | 17 +++++++++++++++--
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index e03847cea131f0..204a1e01b9313c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1264,7 +1264,8 @@ VPlan *VPlan::duplicate() {
          "TripCount must have been added to Old2NewVPValues");
   NewPlan->TripCount = Old2NewVPValues[TripCount];
 
-  for (unsigned I = CreatedBlockSize; I != CreatedBlocks.size(); ++I)
+  // Transfer cloned blocks to new VPlan.
+  for (unsigned I : seq<unsigned>(CreatedBlockSize, CreatedBlocks.size()))
     NewPlan->CreatedBlocks.push_back(CreatedBlocks[I]);
   CreatedBlocks.truncate(CreatedBlockSize);
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index eb0e3baa8d4f38..434b4d7e49ab82 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3714,8 +3714,7 @@ class VPRegionBlock : public VPBlockBase {
       : VPBlockBase(VPRegionBlockSC, Name), Entry(nullptr), Exiting(nullptr),
         IsReplicator(IsReplicator) {}
 
-  ~VPRegionBlock() override {
-  }
+  ~VPRegionBlock() override {}
 
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPBlockBase *V) {
@@ -3841,6 +3840,8 @@ class VPlan {
   /// been modeled in VPlan directly.
   DenseMap<const SCEV *, VPValue *> SCEVToExpansion;
 
+  /// Blocks allocated and owned by the VPlan. They will be deleted once the
+  /// VPlan is destroyed.
   SmallVector<VPBlockBase *> CreatedBlocks;
 
 public:
@@ -4060,6 +4061,9 @@ class VPlan {
   /// recipes to refer to the clones, and return it.
   VPlan *duplicate();
 
+  /// Create a new VPBasicBlock with \p Name and containing \p Recipe if
+  /// present. The returned block is owned by the VPlan and deleted once the
+  /// VPlan is destroyed.
   VPBasicBlock *createVPBasicBlock(const Twine &Name,
                                    VPRecipeBase *Recipe = nullptr) {
     auto *VPB = new VPBasicBlock(Name, Recipe);
@@ -4067,6 +4071,9 @@ class VPlan {
     return VPB;
   }
 
+  /// Create a new VPRegionBlock with \p Entry, \p Exiting and \p Name. If \p
+  /// IsReplicator is true, the region is a replicate region. The returned block
+  /// is owned by the VPlan and deleted once the VPlan is destroyed.
   VPRegionBlock *createVPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exiting,
                                      const std::string &Name = "",
                                      bool IsReplicator = false) {
@@ -4075,6 +4082,10 @@ class VPlan {
     return VPB;
   }
 
+  /// Create a new VPRegionBlock with \p Name and entry and exiting blocks set
+  /// to nullptr. If \p IsReplicator is true, the region is a replicate region.
+  /// The returned block is owned by the VPlan and deleted once the VPlan is
+  /// destroyed.
   VPRegionBlock *createVPRegionBlock(const std::string &Name = "",
                                      bool IsReplicator = false) {
     auto *VPB = new VPRegionBlock(Name, IsReplicator);
@@ -4084,6 +4095,8 @@ class VPlan {
 
   /// Create a VPIRBasicBlock from \p IRBB containing VPIRInstructions for all
   /// instructions in \p IRBB, except its terminator which is managed in VPlan.
+  /// The returned block is owned by the VPlan and deleted once the VPlan is
+  /// destroyed.
   VPIRBasicBlock *createVPIRBasicBlock(BasicBlock *IRBB);
 };
 

>From 407dbc1eccf9e4cd22e1c20f9212fef0b2ead7f4 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 27 Dec 2024 11:25:59 +0000
Subject: [PATCH 3/6] [VPlan] Funnel

---
 llvm/lib/Transforms/Vectorize/VPlan.h         |  14 +-
 .../Transforms/Vectorize/VPlanHCFGBuilder.h   |   4 +-
 .../Transforms/Vectorize/VPDomTreeTest.cpp    |  42 ++--
 .../Transforms/Vectorize/VPlanHCFGTest.cpp    |   2 +-
 .../Transforms/Vectorize/VPlanSlpTest.cpp     |   2 +-
 .../Transforms/Vectorize/VPlanTest.cpp        | 210 ++++++------------
 .../Transforms/Vectorize/VPlanTestBase.h      |  20 +-
 .../Vectorize/VPlanVerifierTest.cpp           |  89 +++-----
 8 files changed, 139 insertions(+), 244 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 434b4d7e49ab82..beabcc7fd41874 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3854,18 +3854,18 @@ class VPlan {
            "scalar header must be a leaf node");
   }
 
-  /// Construct a VPlan with \p Entry entering the plan, trip count \p TC and
-  /// with \p ScalarHeader wrapping the original header of the scalar loop.
-  VPlan(VPBasicBlock *Entry, VPValue *TC, VPIRBasicBlock *ScalarHeader)
-      : VPlan(Entry, ScalarHeader) {
-    TripCount = TC;
-  }
-
+public:
   /// Construct a VPlan for \p L. This will create VPIRBasicBlocks wrapping the
   /// original preheader and scalar header of \p L, to be used as entry and
   /// scalar header blocks of the new VPlan.
   VPlan(Loop *L);
 
+  VPlan(BasicBlock *ScalarHeaderBB, VPValue *TC) {
+    setEntry(new VPBasicBlock("preheader"));
+    ScalarHeader = VPIRBasicBlock::fromBasicBlock(ScalarHeaderBB);
+    TripCount = TC;
+  }
+
   ~VPlan();
 
   void setEntry(VPBasicBlock *VPBB) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h
index 9e8f9f3f400293..ad6e2ad90a9610 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h
@@ -32,11 +32,11 @@ class Loop;
 class LoopInfo;
 class VPRegionBlock;
 class VPlan;
-class VPlanTestBase;
+class VPlanTestIRBase;
 
 /// Main class to build the VPlan H-CFG for an incoming IR.
 class VPlanHCFGBuilder {
-  friend VPlanTestBase;
+  friend VPlanTestIRBase;
 
 private:
   // The outermost loop of the input loop nest considered for vectorization.
diff --git a/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp b/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
index 847cca7714effc..6aa34a5fa431b5 100644
--- a/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
@@ -9,12 +9,15 @@
 
 #include "../lib/Transforms/Vectorize/VPlan.h"
 #include "../lib/Transforms/Vectorize/VPlanDominatorTree.h"
+#include "VPlanTestBase.h"
 #include "gtest/gtest.h"
 
 namespace llvm {
 namespace {
 
-TEST(VPDominatorTreeTest, DominanceNoRegionsTest) {
+using VPDominatorTreeTest = VPlanTestBase;
+
+TEST_F(VPDominatorTreeTest, DominanceNoRegionsTest) {
   //   VPBB0
   //    |
   //   R1 {
@@ -24,8 +27,8 @@ TEST(VPDominatorTreeTest, DominanceNoRegionsTest) {
   //    \    /
   //    VPBB4
   //  }
-  VPBasicBlock *VPPH = new VPBasicBlock("ph");
-  VPBasicBlock *VPBB0 = new VPBasicBlock("VPBB0");
+  VPlan &Plan = getPlan();
+  VPBasicBlock *VPBB0 = Plan.getEntry();
   VPBasicBlock *VPBB1 = new VPBasicBlock("VPBB1");
   VPBasicBlock *VPBB2 = new VPBasicBlock("VPBB2");
   VPBasicBlock *VPBB3 = new VPBasicBlock("VPBB3");
@@ -40,12 +43,7 @@ TEST(VPDominatorTreeTest, DominanceNoRegionsTest) {
   VPBlockUtils::connectBlocks(VPBB2, VPBB4);
   VPBlockUtils::connectBlocks(VPBB3, VPBB4);
 
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(R1, ScalarHeaderVPBB);
-  VPBlockUtils::connectBlocks(VPPH, VPBB0);
-  VPlan Plan(VPPH, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
 
   VPDominatorTree VPDT;
   VPDT.recalculate(Plan);
@@ -62,7 +60,6 @@ TEST(VPDominatorTreeTest, DominanceNoRegionsTest) {
   EXPECT_EQ(VPDT.findNearestCommonDominator(VPBB2, VPBB3), VPBB1);
   EXPECT_EQ(VPDT.findNearestCommonDominator(VPBB2, VPBB4), VPBB1);
   EXPECT_EQ(VPDT.findNearestCommonDominator(VPBB4, VPBB4), VPBB4);
-  delete ScalarHeader;
 }
 
 static void
@@ -76,9 +73,7 @@ checkDomChildren(VPDominatorTree &VPDT, VPBlockBase *Src,
   EXPECT_EQ(Children, ExpectedNodes);
 }
 
-TEST(VPDominatorTreeTest, DominanceRegionsTest) {
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
+TEST_F(VPDominatorTreeTest, DominanceRegionsTest) {
   {
     // 2 consecutive regions.
     // VPBB0
@@ -99,8 +94,8 @@ TEST(VPDominatorTreeTest, DominanceRegionsTest) {
     //    R2BB2
     // }
     //
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
-    VPBasicBlock *VPBB0 = new VPBasicBlock("VPBB0");
+    VPlan &Plan = getPlan();
+    VPBasicBlock *VPBB0 = Plan.getEntry();
     VPBasicBlock *R1BB1 = new VPBasicBlock();
     VPBasicBlock *R1BB2 = new VPBasicBlock();
     VPBasicBlock *R1BB3 = new VPBasicBlock();
@@ -122,10 +117,7 @@ TEST(VPDominatorTreeTest, DominanceRegionsTest) {
     VPBlockUtils::connectBlocks(R2BB1, R2BB2);
     VPBlockUtils::connectBlocks(R1, R2);
 
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(R2, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB0);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(R2, Plan.getScalarHeader());
     VPDominatorTree VPDT;
     VPDT.recalculate(Plan);
 
@@ -177,7 +169,7 @@ TEST(VPDominatorTreeTest, DominanceRegionsTest) {
     //   |
     //  VPBB2
     //
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
+    VPlan &Plan = getPlan();
     VPBasicBlock *R1BB1 = new VPBasicBlock("R1BB1");
     VPBasicBlock *R1BB2 = new VPBasicBlock("R1BB2");
     VPBasicBlock *R1BB3 = new VPBasicBlock("R1BB3");
@@ -199,15 +191,12 @@ TEST(VPDominatorTreeTest, DominanceRegionsTest) {
     VPBlockUtils::connectBlocks(R1BB2, R1BB3);
     VPBlockUtils::connectBlocks(R2, R1BB3);
 
-    VPBasicBlock *VPBB1 = new VPBasicBlock("VPBB1");
+    VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBlockUtils::connectBlocks(VPBB1, R1);
     VPBasicBlock *VPBB2 = new VPBasicBlock("VPBB2");
     VPBlockUtils::connectBlocks(R1, VPBB2);
 
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(VPBB2, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB1);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(VPBB2, Plan.getScalarHeader());
     VPDominatorTree VPDT;
     VPDT.recalculate(Plan);
 
@@ -220,9 +209,8 @@ TEST(VPDominatorTreeTest, DominanceRegionsTest) {
     checkDomChildren(VPDT, R2BB2, {R2BB3});
     checkDomChildren(VPDT, R2BB3, {});
     checkDomChildren(VPDT, R1BB3, {VPBB2});
-    checkDomChildren(VPDT, VPBB2, {ScalarHeaderVPBB});
+    checkDomChildren(VPDT, VPBB2, {Plan.getScalarHeader()});
   }
-  delete ScalarHeader;
 }
 
 } // namespace
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp
index 1b362d1d26bdd3..19c2483d34ed17 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp
@@ -17,7 +17,7 @@
 namespace llvm {
 namespace {
 
-class VPlanHCFGTest : public VPlanTestBase {};
+class VPlanHCFGTest : public VPlanTestIRBase {};
 
 TEST_F(VPlanHCFGTest, testBuildHCFGInnerLoop) {
   const char *ModuleString =
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanSlpTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanSlpTest.cpp
index 1b993b63898caa..e3c542ec5cac85 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanSlpTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanSlpTest.cpp
@@ -16,7 +16,7 @@
 namespace llvm {
 namespace {
 
-class VPlanSlpTest : public VPlanTestBase {
+class VPlanSlpTest : public VPlanTestIRBase {
 protected:
   TargetLibraryInfoImpl TLII;
   TargetLibraryInfo TLI;
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
index f3a1bba518c83c..2ab55f64a20730 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -9,6 +9,7 @@
 
 #include "../lib/Transforms/Vectorize/VPlan.h"
 #include "../lib/Transforms/Vectorize/VPlanCFG.h"
+#include "VPlanTestBase.h"
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/Analysis/VectorUtils.h"
@@ -237,12 +238,13 @@ TEST(VPInstructionTest, releaseOperandsAtDeletion) {
   delete VPV1;
   delete VPV2;
 }
-TEST(VPBasicBlockTest, getPlan) {
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
+
+using VPBasicBlockTest = VPlanTestBase;
+
+TEST_F(VPBasicBlockTest, getPlan) {
   {
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
-    VPBasicBlock *VPBB1 = new VPBasicBlock();
+    VPlan &Plan = getPlan();
+    VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBasicBlock *VPBB2 = new VPBasicBlock();
     VPBasicBlock *VPBB3 = new VPBasicBlock();
     VPBasicBlock *VPBB4 = new VPBasicBlock();
@@ -256,11 +258,7 @@ TEST(VPBasicBlockTest, getPlan) {
     VPBlockUtils::connectBlocks(VPBB1, VPBB3);
     VPBlockUtils::connectBlocks(VPBB2, VPBB4);
     VPBlockUtils::connectBlocks(VPBB3, VPBB4);
-
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(VPBB4, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB1);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(VPBB4, Plan.getScalarHeader());
 
     EXPECT_EQ(&Plan, VPBB1->getPlan());
     EXPECT_EQ(&Plan, VPBB2->getPlan());
@@ -269,20 +267,17 @@ TEST(VPBasicBlockTest, getPlan) {
   }
 
   {
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
+    VPlan &Plan = getPlan();
+    VPBasicBlock *VPBB1 = Plan.getEntry();
     // VPBasicBlock is the entry into the VPlan, followed by a region.
     VPBasicBlock *R1BB1 = new VPBasicBlock();
     VPBasicBlock *R1BB2 = new VPBasicBlock();
     VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB2, "R1");
     VPBlockUtils::connectBlocks(R1BB1, R1BB2);
 
-    VPBasicBlock *VPBB1 = new VPBasicBlock();
     VPBlockUtils::connectBlocks(VPBB1, R1);
 
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(R1, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB1);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
 
     EXPECT_EQ(&Plan, VPBB1->getPlan());
     EXPECT_EQ(&Plan, R1->getPlan());
@@ -291,8 +286,7 @@ TEST(VPBasicBlockTest, getPlan) {
   }
 
   {
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
-
+    VPlan &Plan = getPlan();
     VPBasicBlock *R1BB1 = new VPBasicBlock();
     VPBasicBlock *R1BB2 = new VPBasicBlock();
     VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB2, "R1");
@@ -303,7 +297,7 @@ TEST(VPBasicBlockTest, getPlan) {
     VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
     VPBlockUtils::connectBlocks(R2BB1, R2BB2);
 
-    VPBasicBlock *VPBB1 = new VPBasicBlock();
+    VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBlockUtils::connectBlocks(VPBB1, R1);
     VPBlockUtils::connectBlocks(VPBB1, R2);
 
@@ -311,10 +305,7 @@ TEST(VPBasicBlockTest, getPlan) {
     VPBlockUtils::connectBlocks(R1, VPBB2);
     VPBlockUtils::connectBlocks(R2, VPBB2);
 
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(R2, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB1);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(R2, Plan.getScalarHeader());
 
     EXPECT_EQ(&Plan, VPBB1->getPlan());
     EXPECT_EQ(&Plan, R1->getPlan());
@@ -325,12 +316,9 @@ TEST(VPBasicBlockTest, getPlan) {
     EXPECT_EQ(&Plan, R2BB2->getPlan());
     EXPECT_EQ(&Plan, VPBB2->getPlan());
   }
-  delete ScalarHeader;
 }
 
-TEST(VPBasicBlockTest, TraversingIteratorTest) {
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
+TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
   {
     // VPBasicBlocks only
     //     VPBB1
@@ -339,8 +327,8 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     //    \    /
     //    VPBB4
     //
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
-    VPBasicBlock *VPBB1 = new VPBasicBlock();
+    VPlan &Plan = getPlan();
+    VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBasicBlock *VPBB2 = new VPBasicBlock();
     VPBasicBlock *VPBB3 = new VPBasicBlock();
     VPBasicBlock *VPBB4 = new VPBasicBlock();
@@ -356,11 +344,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     EXPECT_EQ(VPBB1, FromIterator[0]);
     EXPECT_EQ(VPBB2, FromIterator[1]);
 
-    // Use Plan to properly clean up created blocks.
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(VPBB4, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB1);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(VPBB4, Plan.getScalarHeader());
   }
 
   {
@@ -382,8 +366,8 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     //      |
     //    R2BB2
     //
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
-    VPBasicBlock *VPBB0 = new VPBasicBlock("VPBB0");
+    VPlan &Plan = getPlan();
+    VPBasicBlock *VPBB0 = Plan.getEntry();
     VPBasicBlock *R1BB1 = new VPBasicBlock();
     VPBasicBlock *R1BB2 = new VPBasicBlock();
     VPBasicBlock *R1BB3 = new VPBasicBlock();
@@ -458,11 +442,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     EXPECT_EQ(R1BB1, FromIterator[6]);
     EXPECT_EQ(R1, FromIterator[7]);
 
-    // Use Plan to properly clean up created blocks.
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(R2, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB0);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(R2, Plan.getScalarHeader());
   }
 
   {
@@ -486,7 +466,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     //   |
     //  VPBB2
     //
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
+    VPlan &Plan = getPlan();
     VPBasicBlock *R1BB1 = new VPBasicBlock("R1BB1");
     VPBasicBlock *R1BB2 = new VPBasicBlock("R1BB2");
     VPBasicBlock *R1BB3 = new VPBasicBlock("R1BB3");
@@ -508,7 +488,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     VPBlockUtils::connectBlocks(R1BB2, R1BB3);
     VPBlockUtils::connectBlocks(R2, R1BB3);
 
-    VPBasicBlock *VPBB1 = new VPBasicBlock("VPBB1");
+    VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBlockUtils::connectBlocks(VPBB1, R1);
     VPBasicBlock *VPBB2 = new VPBasicBlock("VPBB2");
     VPBlockUtils::connectBlocks(R1, VPBB2);
@@ -543,11 +523,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     EXPECT_EQ(R1, FromIterator[8]);
     EXPECT_EQ(VPBB1, FromIterator[9]);
 
-    // Use Plan to properly clean up created blocks.
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(VPBB2, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB1);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(VPBB2, Plan.getScalarHeader());
   }
 
   {
@@ -561,7 +537,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     //      R2BB2
     //   }
     //
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
+    VPlan &Plan = getPlan();
     VPBasicBlock *R2BB1 = new VPBasicBlock("R2BB1");
     VPBasicBlock *R2BB2 = new VPBasicBlock("R2BB2");
     VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
@@ -570,7 +546,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     VPRegionBlock *R1 = new VPRegionBlock(R2, R2, "R1");
     R2->setParent(R1);
 
-    VPBasicBlock *VPBB1 = new VPBasicBlock("VPBB1");
+    VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBlockUtils::connectBlocks(VPBB1, R1);
 
     // Depth-first.
@@ -593,11 +569,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     EXPECT_EQ(R1, FromIterator[3]);
     EXPECT_EQ(VPBB1, FromIterator[4]);
 
-    // Use Plan to properly clean up created blocks.
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(R1, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB1);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
   }
 
   {
@@ -619,7 +591,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     //   |
     //  VPBB2
     //
-    VPBasicBlock *VPPH = new VPBasicBlock("ph");
+    VPlan &Plan = getPlan();
     VPBasicBlock *R3BB1 = new VPBasicBlock("R3BB1");
     VPRegionBlock *R3 = new VPRegionBlock(R3BB1, R3BB1, "R3");
 
@@ -631,7 +603,7 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     VPRegionBlock *R1 = new VPRegionBlock(R2, R2, "R1");
     R2->setParent(R1);
 
-    VPBasicBlock *VPBB1 = new VPBasicBlock("VPBB1");
+    VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBasicBlock *VPBB2 = new VPBasicBlock("VPBB2");
     VPBlockUtils::connectBlocks(VPBB1, R1);
     VPBlockUtils::connectBlocks(R1, VPBB2);
@@ -687,19 +659,15 @@ TEST(VPBasicBlockTest, TraversingIteratorTest) {
     EXPECT_EQ(R2BB1, FromIterator[2]);
     EXPECT_EQ(VPBB1, FromIterator[3]);
 
-    // Use Plan to properly clean up created blocks.
-    VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-    VPBlockUtils::connectBlocks(VPBB2, ScalarHeaderVPBB);
-    VPBlockUtils::connectBlocks(VPPH, VPBB1);
-    VPlan Plan(VPPH, ScalarHeaderVPBB);
+    VPBlockUtils::connectBlocks(VPBB2, Plan.getScalarHeader());
   }
-  delete ScalarHeader;
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-TEST(VPBasicBlockTest, print) {
+TEST_F(VPBasicBlockTest, print) {
   VPInstruction *TC = new VPInstruction(Instruction::Add, {});
-  VPBasicBlock *VPBB0 = new VPBasicBlock("preheader");
+  VPlan &Plan = getPlan(TC);
+  VPBasicBlock *VPBB0 = Plan.getEntry();
   VPBB0->appendRecipe(TC);
 
   VPInstruction *I1 = new VPInstruction(Instruction::Add, {});
@@ -730,12 +698,8 @@ TEST(VPBasicBlockTest, print) {
     EXPECT_EQ("EMIT br <badref>, <badref>", I3Dump);
   }
 
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "scalar.header");
-  auto * ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(VPBB2, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(VPBB2, Plan.getScalarHeader());
   VPBlockUtils::connectBlocks(VPBB0, VPBB1);
-  VPlan Plan(VPBB0, TC, ScalarHeaderVPBB);
   std::string FullDump;
   raw_string_ostream OS(FullDump);
   Plan.printDOT(OS);
@@ -810,13 +774,12 @@ Successor(s): ir-bb<scalar.header>
     OS << *I4;
     EXPECT_EQ("EMIT vp<%5> = mul vp<%3>, vp<%2>", I4Dump);
   }
-  delete ScalarHeader;
 }
 
-TEST(VPBasicBlockTest, printPlanWithVFsAndUFs) {
-
+TEST_F(VPBasicBlockTest, printPlanWithVFsAndUFs) {
   VPInstruction *TC = new VPInstruction(Instruction::Sub, {});
-  VPBasicBlock *VPBB0 = new VPBasicBlock("preheader");
+  VPlan &Plan = getPlan(TC);
+  VPBasicBlock *VPBB0 = Plan.getEntry();
   VPBB0->appendRecipe(TC);
 
   VPInstruction *I1 = new VPInstruction(Instruction::Add, {});
@@ -824,12 +787,8 @@ TEST(VPBasicBlockTest, printPlanWithVFsAndUFs) {
   VPBB1->appendRecipe(I1);
   VPBB1->setName("bb1");
 
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(VPBB1, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(VPBB1, Plan.getScalarHeader());
   VPBlockUtils::connectBlocks(VPBB0, VPBB1);
-  VPlan Plan(VPBB0, TC, ScalarHeaderVPBB);
   Plan.setName("TestPlan");
   Plan.addVF(ElementCount::getFixed(4));
 
@@ -847,9 +806,9 @@ Successor(s): bb1
 
 bb1:
   EMIT vp<%2> = add
-Successor(s): ir-bb<>
+Successor(s): ir-bb<scalar.header>
 
-ir-bb<>:
+ir-bb<scalar.header>:
 No successors
 }
 )";
@@ -871,9 +830,9 @@ Successor(s): bb1
 
 bb1:
   EMIT vp<%2> = add
-Successor(s): ir-bb<>
+Successor(s): ir-bb<scalar.header>
 
-ir-bb<>:
+ir-bb<scalar.header>:
 No successors
 }
 )";
@@ -895,19 +854,19 @@ Successor(s): bb1
 
 bb1:
   EMIT vp<%2> = add
-Successor(s): ir-bb<>
+Successor(s): ir-bb<scalar.header>
 
-ir-bb<>:
+ir-bb<scalar.header>:
 No successors
 }
 )";
     EXPECT_EQ(ExpectedStr, FullDump);
   }
-  delete ScalarHeader;
 }
 #endif
 
-TEST(VPRecipeTest, CastVPInstructionToVPUser) {
+using VPRecipeTest = VPlanTestBase;
+TEST_F(VPRecipeTest, CastVPInstructionToVPUser) {
   VPValue Op1;
   VPValue Op2;
   VPInstruction Recipe(Instruction::Add, {&Op1, &Op2});
@@ -917,9 +876,7 @@ TEST(VPRecipeTest, CastVPInstructionToVPUser) {
   EXPECT_EQ(&Recipe, BaseR);
 }
 
-TEST(VPRecipeTest, CastVPWidenRecipeToVPUser) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPWidenRecipeToVPUser) {
   IntegerType *Int32 = IntegerType::get(C, 32);
   auto *AI = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
                                        PoisonValue::get(Int32));
@@ -936,9 +893,7 @@ TEST(VPRecipeTest, CastVPWidenRecipeToVPUser) {
   delete AI;
 }
 
-TEST(VPRecipeTest, CastVPWidenCallRecipeToVPUserAndVPDef) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPWidenCallRecipeToVPUserAndVPDef) {
   IntegerType *Int32 = IntegerType::get(C, 32);
   FunctionType *FTy = FunctionType::get(Int32, false);
   Function *Fn = Function::Create(FTy, GlobalValue::ExternalLinkage, 0);
@@ -964,9 +919,7 @@ TEST(VPRecipeTest, CastVPWidenCallRecipeToVPUserAndVPDef) {
   delete Fn;
 }
 
-TEST(VPRecipeTest, CastVPWidenSelectRecipeToVPUserAndVPDef) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPWidenSelectRecipeToVPUserAndVPDef) {
   IntegerType *Int1 = IntegerType::get(C, 1);
   IntegerType *Int32 = IntegerType::get(C, 32);
   auto *SelectI = SelectInst::Create(
@@ -992,9 +945,7 @@ TEST(VPRecipeTest, CastVPWidenSelectRecipeToVPUserAndVPDef) {
   delete SelectI;
 }
 
-TEST(VPRecipeTest, CastVPWidenGEPRecipeToVPUserAndVPDef) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPWidenGEPRecipeToVPUserAndVPDef) {
   IntegerType *Int32 = IntegerType::get(C, 32);
   PointerType *Int32Ptr = PointerType::get(Int32, 0);
   auto *GEP = GetElementPtrInst::Create(Int32, PoisonValue::get(Int32Ptr),
@@ -1017,9 +968,7 @@ TEST(VPRecipeTest, CastVPWidenGEPRecipeToVPUserAndVPDef) {
   delete GEP;
 }
 
-TEST(VPRecipeTest, CastVPBlendRecipeToVPUser) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPBlendRecipeToVPUser) {
   IntegerType *Int32 = IntegerType::get(C, 32);
   auto *Phi = PHINode::Create(Int32, 1);
   VPValue I1;
@@ -1036,9 +985,7 @@ TEST(VPRecipeTest, CastVPBlendRecipeToVPUser) {
   delete Phi;
 }
 
-TEST(VPRecipeTest, CastVPInterleaveRecipeToVPUser) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPInterleaveRecipeToVPUser) {
   VPValue Addr;
   VPValue Mask;
   InterleaveGroup<Instruction> IG(4, false, Align(4));
@@ -1049,9 +996,7 @@ TEST(VPRecipeTest, CastVPInterleaveRecipeToVPUser) {
   EXPECT_EQ(&Recipe, BaseR);
 }
 
-TEST(VPRecipeTest, CastVPReplicateRecipeToVPUser) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPReplicateRecipeToVPUser) {
   VPValue Op1;
   VPValue Op2;
   SmallVector<VPValue *, 4> Args;
@@ -1068,9 +1013,7 @@ TEST(VPRecipeTest, CastVPReplicateRecipeToVPUser) {
   delete Call;
 }
 
-TEST(VPRecipeTest, CastVPBranchOnMaskRecipeToVPUser) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPBranchOnMaskRecipeToVPUser) {
   VPValue Mask;
   VPBranchOnMaskRecipe Recipe(&Mask);
   EXPECT_TRUE(isa<VPUser>(&Recipe));
@@ -1079,9 +1022,7 @@ TEST(VPRecipeTest, CastVPBranchOnMaskRecipeToVPUser) {
   EXPECT_EQ(&Recipe, BaseR);
 }
 
-TEST(VPRecipeTest, CastVPWidenMemoryRecipeToVPUserAndVPDef) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPWidenMemoryRecipeToVPUserAndVPDef) {
   IntegerType *Int32 = IntegerType::get(C, 32);
   PointerType *Int32Ptr = PointerType::get(Int32, 0);
   auto *Load =
@@ -1101,8 +1042,7 @@ TEST(VPRecipeTest, CastVPWidenMemoryRecipeToVPUserAndVPDef) {
   delete Load;
 }
 
-TEST(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
-  LLVMContext C;
+TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
   IntegerType *Int1 = IntegerType::get(C, 1);
   IntegerType *Int32 = IntegerType::get(C, 32);
   PointerType *Int32Ptr = PointerType::get(Int32, 0);
@@ -1242,7 +1182,6 @@ TEST(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
 
   {
     // Test for a call to a function without side-effects.
-    LLVMContext C;
     Module M("", C);
     Function *TheFn =
         Intrinsic::getOrInsertDeclaration(&M, Intrinsic::thread_pointer);
@@ -1296,15 +1235,12 @@ TEST(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-TEST(VPRecipeTest, dumpRecipeInPlan) {
-  VPBasicBlock *VPBB0 = new VPBasicBlock("preheader");
+TEST_F(VPRecipeTest, dumpRecipeInPlan) {
+  VPlan &Plan = getPlan();
+  VPBasicBlock *VPBB0 = Plan.getEntry();
   VPBasicBlock *VPBB1 = new VPBasicBlock();
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(VPBB1, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(VPBB1, Plan.getScalarHeader());
   VPBlockUtils::connectBlocks(VPBB0, VPBB1);
-  VPlan Plan(VPBB0, ScalarHeaderVPBB);
 
   IntegerType *Int32 = IntegerType::get(C, 32);
   auto *AI = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
@@ -1366,18 +1302,14 @@ TEST(VPRecipeTest, dumpRecipeInPlan) {
   }
 
   delete AI;
-  delete ScalarHeader;
 }
 
-TEST(VPRecipeTest, dumpRecipeUnnamedVPValuesInPlan) {
-  VPBasicBlock *VPBB0 = new VPBasicBlock("preheader");
+TEST_F(VPRecipeTest, dumpRecipeUnnamedVPValuesInPlan) {
+  VPlan &Plan = getPlan();
+  VPBasicBlock *VPBB0 = Plan.getEntry();
   VPBasicBlock *VPBB1 = new VPBasicBlock();
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(VPBB1, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(VPBB1, Plan.getScalarHeader());
   VPBlockUtils::connectBlocks(VPBB0, VPBB1);
-  VPlan Plan(VPBB0, ScalarHeaderVPBB);
 
   IntegerType *Int32 = IntegerType::get(C, 32);
   auto *AI = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
@@ -1456,11 +1388,9 @@ TEST(VPRecipeTest, dumpRecipeUnnamedVPValuesInPlan) {
         testing::ExitedWithCode(0), "EMIT vp<%2> = mul vp<%1>, vp<%1>");
   }
   delete AI;
-  delete ScalarHeader;
 }
 
-TEST(VPRecipeTest, dumpRecipeUnnamedVPValuesNotInPlanOrBlock) {
-  LLVMContext C;
+TEST_F(VPRecipeTest, dumpRecipeUnnamedVPValuesNotInPlanOrBlock) {
   IntegerType *Int32 = IntegerType::get(C, 32);
   auto *AI = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
                                        PoisonValue::get(Int32));
@@ -1543,9 +1473,7 @@ TEST(VPRecipeTest, dumpRecipeUnnamedVPValuesNotInPlanOrBlock) {
 
 #endif
 
-TEST(VPRecipeTest, CastVPReductionRecipeToVPUser) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPReductionRecipeToVPUser) {
   VPValue ChainOp;
   VPValue VecOp;
   VPValue CondOp;
@@ -1556,9 +1484,7 @@ TEST(VPRecipeTest, CastVPReductionRecipeToVPUser) {
   EXPECT_TRUE(isa<VPUser>(BaseR));
 }
 
-TEST(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
-  LLVMContext C;
-
+TEST_F(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
   VPValue ChainOp;
   VPValue VecOp;
   VPValue CondOp;
@@ -1630,7 +1556,7 @@ TEST(VPDoubleValueDefTest, traverseUseLists) {
   EXPECT_EQ(&DoubleValueDef, I3.getOperand(0)->getDefiningRecipe());
 }
 
-TEST(VPRecipeTest, CastToVPSingleDefRecipe) {
+TEST_F(VPRecipeTest, CastToVPSingleDefRecipe) {
   VPValue Start;
   VPEVLBasedIVPHIRecipe R(&Start, {});
   VPRecipeBase *B = &R;
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
index 06e091da9054e3..1836a5e39a290e 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
@@ -28,7 +28,7 @@ namespace llvm {
 
 /// Helper class to create a module from an assembly string and VPlans for a
 /// given loop entry block.
-class VPlanTestBase : public testing::Test {
+class VPlanTestIRBase : public testing::Test {
 protected:
   TargetLibraryInfoImpl TLII;
   TargetLibraryInfo TLI;
@@ -41,7 +41,7 @@ class VPlanTestBase : public testing::Test {
   std::unique_ptr<AssumptionCache> AC;
   std::unique_ptr<ScalarEvolution> SE;
 
-  VPlanTestBase()
+  VPlanTestIRBase()
       : TLII(), TLI(TLII),
         DL("e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-"
            "f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:"
@@ -92,6 +92,22 @@ class VPlanTestBase : public testing::Test {
   }
 };
 
+class VPlanTestBase : public testing::Test {
+protected:
+  LLVMContext C;
+  std::unique_ptr<BasicBlock> ScalarHeader;
+  SmallVector<std::unique_ptr<VPlan>> Plans;
+
+  VPlanTestBase() : ScalarHeader(BasicBlock::Create(C, "scalar.header")) {
+    BranchInst::Create(&*ScalarHeader, &*ScalarHeader);
+  }
+
+  VPlan &getPlan(VPValue *TC = nullptr) {
+    Plans.push_back(std::make_unique<VPlan>(&*ScalarHeader, TC));
+    return *Plans.back();
+  }
+};
+
 } // namespace llvm
 
 #endif // LLVM_UNITTESTS_TRANSFORMS_VECTORIZE_VPLANTESTBASE_H
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanVerifierTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanVerifierTest.cpp
index 6448153de7821c..174249a7e85e32 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanVerifierTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanVerifierTest.cpp
@@ -8,32 +8,29 @@
 
 #include "../lib/Transforms/Vectorize/VPlanVerifier.h"
 #include "../lib/Transforms/Vectorize/VPlan.h"
+#include "VPlanTestBase.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
 
+using VPVerifierTest = VPlanTestBase;
+
 namespace {
-TEST(VPVerifierTest, VPInstructionUseBeforeDefSameBB) {
+TEST_F(VPVerifierTest, VPInstructionUseBeforeDefSameBB) {
+  VPlan &Plan = getPlan();
   VPInstruction *DefI = new VPInstruction(Instruction::Add, {});
   VPInstruction *UseI = new VPInstruction(Instruction::Sub, {DefI});
 
-  VPBasicBlock *VPPH = new VPBasicBlock("ph");
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+  VPBasicBlock *VPBB1 = Plan.getEntry();
   VPBB1->appendRecipe(UseI);
   VPBB1->appendRecipe(DefI);
 
   VPBasicBlock *VPBB2 = new VPBasicBlock();
   VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB2, "R1");
   VPBlockUtils::connectBlocks(VPBB1, R1);
-
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(R1, ScalarHeaderVPBB);
-  VPBlockUtils::connectBlocks(VPPH, VPBB1);
-  VPlan Plan(VPPH, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
 
 #if GTEST_HAS_STREAM_REDIRECTION
   ::testing::internal::CaptureStderr();
@@ -43,18 +40,17 @@ TEST(VPVerifierTest, VPInstructionUseBeforeDefSameBB) {
   EXPECT_STREQ("Use before def!\n",
                ::testing::internal::GetCapturedStderr().c_str());
 #endif
-  delete ScalarHeader;
 }
 
-TEST(VPVerifierTest, VPInstructionUseBeforeDefDifferentBB) {
+TEST_F(VPVerifierTest, VPInstructionUseBeforeDefDifferentBB) {
+  VPlan &Plan = getPlan();
   VPInstruction *DefI = new VPInstruction(Instruction::Add, {});
   VPInstruction *UseI = new VPInstruction(Instruction::Sub, {DefI});
   auto *CanIV = new VPCanonicalIVPHIRecipe(UseI, {});
   VPInstruction *BranchOnCond =
       new VPInstruction(VPInstruction::BranchOnCond, {CanIV});
 
-  VPBasicBlock *VPPH = new VPBasicBlock("ph");
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+  VPBasicBlock *VPBB1 = Plan.getEntry();
   VPBasicBlock *VPBB2 = new VPBasicBlock();
 
   VPBB1->appendRecipe(UseI);
@@ -64,13 +60,7 @@ TEST(VPVerifierTest, VPInstructionUseBeforeDefDifferentBB) {
 
   VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB2, "R1");
   VPBlockUtils::connectBlocks(VPBB1, R1);
-
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(R1, ScalarHeaderVPBB);
-  VPBlockUtils::connectBlocks(VPPH, VPBB1);
-  VPlan Plan(VPPH, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
 
 #if GTEST_HAS_STREAM_REDIRECTION
   ::testing::internal::CaptureStderr();
@@ -80,11 +70,9 @@ TEST(VPVerifierTest, VPInstructionUseBeforeDefDifferentBB) {
   EXPECT_STREQ("Use before def!\n",
                ::testing::internal::GetCapturedStderr().c_str());
 #endif
-  delete ScalarHeader;
 }
 
-TEST(VPVerifierTest, VPBlendUseBeforeDefDifferentBB) {
-  LLVMContext C;
+TEST_F(VPVerifierTest, VPBlendUseBeforeDefDifferentBB) {
   IntegerType *Int32 = IntegerType::get(C, 32);
   auto *Phi = PHINode::Create(Int32, 1);
 
@@ -95,8 +83,8 @@ TEST(VPVerifierTest, VPBlendUseBeforeDefDifferentBB) {
       new VPInstruction(VPInstruction::BranchOnCond, {CanIV});
   auto *Blend = new VPBlendRecipe(Phi, {DefI});
 
-  VPBasicBlock *VPPH = new VPBasicBlock("ph");
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+  VPlan &Plan = getPlan();
+  VPBasicBlock *VPBB1 = Plan.getEntry();
   VPBasicBlock *VPBB2 = new VPBasicBlock();
   VPBasicBlock *VPBB3 = new VPBasicBlock();
   VPBasicBlock *VPBB4 = new VPBasicBlock();
@@ -113,11 +101,7 @@ TEST(VPVerifierTest, VPBlendUseBeforeDefDifferentBB) {
   VPBlockUtils::connectBlocks(VPBB1, R1);
   VPBB3->setParent(R1);
 
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(R1, ScalarHeaderVPBB);
-  VPBlockUtils::connectBlocks(VPPH, VPBB1);
-  VPlan Plan(VPPH, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
 
 #if GTEST_HAS_STREAM_REDIRECTION
   ::testing::internal::CaptureStderr();
@@ -129,10 +113,9 @@ TEST(VPVerifierTest, VPBlendUseBeforeDefDifferentBB) {
 #endif
 
   delete Phi;
-  delete ScalarHeader;
 }
 
-TEST(VPVerifierTest, DuplicateSuccessorsOutsideRegion) {
+TEST_F(VPVerifierTest, DuplicateSuccessorsOutsideRegion) {
   VPInstruction *I1 = new VPInstruction(Instruction::Add, {});
   auto *CanIV = new VPCanonicalIVPHIRecipe(I1, {});
   VPInstruction *BranchOnCond =
@@ -140,8 +123,8 @@ TEST(VPVerifierTest, DuplicateSuccessorsOutsideRegion) {
   VPInstruction *BranchOnCond2 =
       new VPInstruction(VPInstruction::BranchOnCond, {I1});
 
-  VPBasicBlock *VPPH = new VPBasicBlock("ph");
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+  VPlan &Plan = getPlan();
+  VPBasicBlock *VPBB1 = Plan.getEntry();
   VPBasicBlock *VPBB2 = new VPBasicBlock();
 
   VPBB1->appendRecipe(I1);
@@ -153,12 +136,7 @@ TEST(VPVerifierTest, DuplicateSuccessorsOutsideRegion) {
   VPBlockUtils::connectBlocks(VPBB1, R1);
   VPBlockUtils::connectBlocks(VPBB1, R1);
 
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(R1, ScalarHeaderVPBB);
-  VPBlockUtils::connectBlocks(VPPH, VPBB1);
-  VPlan Plan(VPPH, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
 
 #if GTEST_HAS_STREAM_REDIRECTION
   ::testing::internal::CaptureStderr();
@@ -168,10 +146,9 @@ TEST(VPVerifierTest, DuplicateSuccessorsOutsideRegion) {
   EXPECT_STREQ("Multiple instances of the same successor.\n",
                ::testing::internal::GetCapturedStderr().c_str());
 #endif
-  delete ScalarHeader;
 }
 
-TEST(VPVerifierTest, DuplicateSuccessorsInsideRegion) {
+TEST_F(VPVerifierTest, DuplicateSuccessorsInsideRegion) {
   VPInstruction *I1 = new VPInstruction(Instruction::Add, {});
   auto *CanIV = new VPCanonicalIVPHIRecipe(I1, {});
   VPInstruction *BranchOnCond =
@@ -179,8 +156,8 @@ TEST(VPVerifierTest, DuplicateSuccessorsInsideRegion) {
   VPInstruction *BranchOnCond2 =
       new VPInstruction(VPInstruction::BranchOnCond, {I1});
 
-  VPBasicBlock *VPPH = new VPBasicBlock("ph");
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+  VPlan &Plan = getPlan();
+  VPBasicBlock *VPBB1 = Plan.getEntry();
   VPBasicBlock *VPBB2 = new VPBasicBlock();
   VPBasicBlock *VPBB3 = new VPBasicBlock();
 
@@ -195,12 +172,7 @@ TEST(VPVerifierTest, DuplicateSuccessorsInsideRegion) {
   VPBlockUtils::connectBlocks(VPBB1, R1);
   VPBB3->setParent(R1);
 
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(R1, ScalarHeaderVPBB);
-  VPBlockUtils::connectBlocks(VPPH, VPBB1);
-  VPlan Plan(VPPH, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
 
 #if GTEST_HAS_STREAM_REDIRECTION
   ::testing::internal::CaptureStderr();
@@ -210,12 +182,11 @@ TEST(VPVerifierTest, DuplicateSuccessorsInsideRegion) {
   EXPECT_STREQ("Multiple instances of the same successor.\n",
                ::testing::internal::GetCapturedStderr().c_str());
 #endif
-  delete ScalarHeader;
 }
 
-TEST(VPVerifierTest, BlockOutsideRegionWithParent) {
-  VPBasicBlock *VPPH = new VPBasicBlock("ph");
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+TEST_F(VPVerifierTest, BlockOutsideRegionWithParent) {
+  VPlan &Plan = getPlan();
+  VPBasicBlock *VPBB1 = Plan.getEntry();
   VPBasicBlock *VPBB2 = new VPBasicBlock();
 
   VPInstruction *DefI = new VPInstruction(Instruction::Add, {});
@@ -228,12 +199,7 @@ TEST(VPVerifierTest, BlockOutsideRegionWithParent) {
   VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB2, "R1");
   VPBlockUtils::connectBlocks(VPBB1, R1);
 
-  LLVMContext C;
-  auto *ScalarHeader = BasicBlock::Create(C, "");
-  VPIRBasicBlock *ScalarHeaderVPBB = new VPIRBasicBlock(ScalarHeader);
-  VPBlockUtils::connectBlocks(R1, ScalarHeaderVPBB);
-  VPBlockUtils::connectBlocks(VPPH, VPBB1);
-  VPlan Plan(VPPH, ScalarHeaderVPBB);
+  VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
   VPBB1->setParent(R1);
 
 #if GTEST_HAS_STREAM_REDIRECTION
@@ -244,7 +210,6 @@ TEST(VPVerifierTest, BlockOutsideRegionWithParent) {
   EXPECT_STREQ("Predecessor is not in the same region.\n",
                ::testing::internal::GetCapturedStderr().c_str());
 #endif
-  delete ScalarHeader;
 }
 
 } // namespace

>From af48fccd3d35851f25890f7844aebf2c57a566cc Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 27 Dec 2024 11:25:59 +0000
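Subject: [PATCH 4/6] [VPlan] Funnel block creation through VPlan helpers

Create blocks via VPlan::createVPBasicBlock and VPlan::createVPIRBasicBlock
instead of allocating them directly:

 * VPlan(BasicBlock *, VPValue *) now creates its preheader and scalar
   header through the helpers.
 * PlainCFGBuilder no longer deletes the unused latch block when the loop
   header is also the latch; blocks created for the plan are freed by the
   plan itself.
 * Unit tests create their basic blocks through Plan.createVPBasicBlock.
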
---
 llvm/lib/Transforms/Vectorize/VPlan.h         |  5 +-
 .../Transforms/Vectorize/VPlanHCFGBuilder.cpp |  6 +-
 .../Transforms/Vectorize/VPDomTreeTest.cpp    | 35 +++++----
 .../Transforms/Vectorize/VPlanTest.cpp        | 77 ++++++++++---------
 .../Vectorize/VPlanVerifierTest.cpp           | 18 ++---
 5 files changed, 72 insertions(+), 69 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index beabcc7fd41874..f235ed37e2c9b0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3844,7 +3844,6 @@ class VPlan {
   /// VPlan is destroyed.
   SmallVector<VPBlockBase *> CreatedBlocks;
 
-public:
   /// Construct a VPlan with \p Entry to the plan and with \p ScalarHeader
   /// wrapping the original header of the scalar loop.
   VPlan(VPBasicBlock *Entry, VPIRBasicBlock *ScalarHeader)
@@ -3861,8 +3860,8 @@ class VPlan {
   VPlan(Loop *L);
 
   VPlan(BasicBlock *ScalarHeaderBB, VPValue *TC) {
-    setEntry(new VPBasicBlock("preheader"));
-    ScalarHeader = VPIRBasicBlock::fromBasicBlock(ScalarHeaderBB);
+    setEntry(createVPBasicBlock("preheader"));
+    ScalarHeader = createVPIRBasicBlock(ScalarHeaderBB);
     TripCount = TC;
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
index 02f4c8d8872d82..76ed578424dfec 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
@@ -357,12 +357,10 @@ void PlainCFGBuilder::buildPlainCFG() {
   BB2VPBB[TheLoop->getHeader()] = VectorHeaderVPBB;
   VectorHeaderVPBB->clearSuccessors();
   VectorLatchVPBB->clearPredecessors();
-  if (TheLoop->getHeader() != TheLoop->getLoopLatch()) {
+  if (TheLoop->getHeader() != TheLoop->getLoopLatch())
     BB2VPBB[TheLoop->getLoopLatch()] = VectorLatchVPBB;
-  } else {
+  else
     TheRegion->setExiting(VectorHeaderVPBB);
-    delete VectorLatchVPBB;
-  }
 
   // 1. Scan the body of the loop in a topological order to visit each basic
   // block after having visited its predecessor basic blocks. Create a VPBB for
diff --git a/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp b/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
index 6aa34a5fa431b5..4e1415fa7ac135 100644
--- a/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
@@ -29,10 +29,10 @@ TEST_F(VPDominatorTreeTest, DominanceNoRegionsTest) {
   //  }
   VPlan &Plan = getPlan();
   VPBasicBlock *VPBB0 = Plan.getEntry();
-  VPBasicBlock *VPBB1 = new VPBasicBlock("VPBB1");
-  VPBasicBlock *VPBB2 = new VPBasicBlock("VPBB2");
-  VPBasicBlock *VPBB3 = new VPBasicBlock("VPBB3");
-  VPBasicBlock *VPBB4 = new VPBasicBlock("VPBB4");
+  VPBasicBlock *VPBB1 = Plan.createVPBasicBlock("VPBB1");
+  VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("VPBB2");
+  VPBasicBlock *VPBB3 = Plan.createVPBasicBlock("VPBB3");
+  VPBasicBlock *VPBB4 = Plan.createVPBasicBlock("VPBB4");
   VPRegionBlock *R1 = new VPRegionBlock(VPBB1, VPBB4);
   VPBB2->setParent(R1);
   VPBB3->setParent(R1);
@@ -96,10 +96,10 @@ TEST_F(VPDominatorTreeTest, DominanceRegionsTest) {
     //
     VPlan &Plan = getPlan();
     VPBasicBlock *VPBB0 = Plan.getEntry();
-    VPBasicBlock *R1BB1 = new VPBasicBlock();
-    VPBasicBlock *R1BB2 = new VPBasicBlock();
-    VPBasicBlock *R1BB3 = new VPBasicBlock();
-    VPBasicBlock *R1BB4 = new VPBasicBlock();
+    VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R1BB4 = Plan.createVPBasicBlock("");
     VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB4, "R1");
     R1BB2->setParent(R1);
     R1BB3->setParent(R1);
@@ -111,8 +111,8 @@ TEST_F(VPDominatorTreeTest, DominanceRegionsTest) {
     // Cycle.
     VPBlockUtils::connectBlocks(R1BB3, R1BB3);
 
-    VPBasicBlock *R2BB1 = new VPBasicBlock();
-    VPBasicBlock *R2BB2 = new VPBasicBlock();
+    VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("");
     VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
     VPBlockUtils::connectBlocks(R2BB1, R2BB2);
     VPBlockUtils::connectBlocks(R1, R2);
@@ -170,14 +170,14 @@ TEST_F(VPDominatorTreeTest, DominanceRegionsTest) {
     //  VPBB2
     //
     VPlan &Plan = getPlan();
-    VPBasicBlock *R1BB1 = new VPBasicBlock("R1BB1");
-    VPBasicBlock *R1BB2 = new VPBasicBlock("R1BB2");
-    VPBasicBlock *R1BB3 = new VPBasicBlock("R1BB3");
+    VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("R1BB1");
+    VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("R1BB2");
+    VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("R1BB3");
     VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB3, "R1");
 
-    VPBasicBlock *R2BB1 = new VPBasicBlock("R2BB1");
-    VPBasicBlock *R2BB2 = new VPBasicBlock("R2BB2");
-    VPBasicBlock *R2BB3 = new VPBasicBlock("R2BB3");
+    VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R2BB3 = Plan.createVPBasicBlock("");
     VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB3, "R2");
     R2BB2->setParent(R2);
     VPBlockUtils::connectBlocks(R2BB1, R2BB2);
@@ -193,7 +193,8 @@ TEST_F(VPDominatorTreeTest, DominanceRegionsTest) {
 
     VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBlockUtils::connectBlocks(VPBB1, R1);
-    VPBasicBlock *VPBB2 = new VPBasicBlock("VPBB2");
+    VPBasicBlock *VPBB2 = Plan.createVPBasicBlock(""
+                                                  "VPBB2");
     VPBlockUtils::connectBlocks(R1, VPBB2);
 
     VPBlockUtils::connectBlocks(VPBB2, Plan.getScalarHeader());
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
index 2ab55f64a20730..69283034991311 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -245,9 +245,9 @@ TEST_F(VPBasicBlockTest, getPlan) {
   {
     VPlan &Plan = getPlan();
     VPBasicBlock *VPBB1 = Plan.getEntry();
-    VPBasicBlock *VPBB2 = new VPBasicBlock();
-    VPBasicBlock *VPBB3 = new VPBasicBlock();
-    VPBasicBlock *VPBB4 = new VPBasicBlock();
+    VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
+    VPBasicBlock *VPBB3 = Plan.createVPBasicBlock("");
+    VPBasicBlock *VPBB4 = Plan.createVPBasicBlock("");
 
     //     VPBB1
     //     /   \
@@ -270,8 +270,8 @@ TEST_F(VPBasicBlockTest, getPlan) {
     VPlan &Plan = getPlan();
     VPBasicBlock *VPBB1 = Plan.getEntry();
     // VPBasicBlock is the entry into the VPlan, followed by a region.
-    VPBasicBlock *R1BB1 = new VPBasicBlock();
-    VPBasicBlock *R1BB2 = new VPBasicBlock();
+    VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("");
     VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB2, "R1");
     VPBlockUtils::connectBlocks(R1BB1, R1BB2);
 
@@ -287,13 +287,13 @@ TEST_F(VPBasicBlockTest, getPlan) {
 
   {
     VPlan &Plan = getPlan();
-    VPBasicBlock *R1BB1 = new VPBasicBlock();
-    VPBasicBlock *R1BB2 = new VPBasicBlock();
+    VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("");
     VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB2, "R1");
     VPBlockUtils::connectBlocks(R1BB1, R1BB2);
 
-    VPBasicBlock *R2BB1 = new VPBasicBlock();
-    VPBasicBlock *R2BB2 = new VPBasicBlock();
+    VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("");
     VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
     VPBlockUtils::connectBlocks(R2BB1, R2BB2);
 
@@ -301,7 +301,7 @@ TEST_F(VPBasicBlockTest, getPlan) {
     VPBlockUtils::connectBlocks(VPBB1, R1);
     VPBlockUtils::connectBlocks(VPBB1, R2);
 
-    VPBasicBlock *VPBB2 = new VPBasicBlock();
+    VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
     VPBlockUtils::connectBlocks(R1, VPBB2);
     VPBlockUtils::connectBlocks(R2, VPBB2);
 
@@ -329,9 +329,9 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
     //
     VPlan &Plan = getPlan();
     VPBasicBlock *VPBB1 = Plan.getEntry();
-    VPBasicBlock *VPBB2 = new VPBasicBlock();
-    VPBasicBlock *VPBB3 = new VPBasicBlock();
-    VPBasicBlock *VPBB4 = new VPBasicBlock();
+    VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
+    VPBasicBlock *VPBB3 = Plan.createVPBasicBlock("");
+    VPBasicBlock *VPBB4 = Plan.createVPBasicBlock("");
 
     VPBlockUtils::connectBlocks(VPBB1, VPBB2);
     VPBlockUtils::connectBlocks(VPBB1, VPBB3);
@@ -368,10 +368,10 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
     //
     VPlan &Plan = getPlan();
     VPBasicBlock *VPBB0 = Plan.getEntry();
-    VPBasicBlock *R1BB1 = new VPBasicBlock();
-    VPBasicBlock *R1BB2 = new VPBasicBlock();
-    VPBasicBlock *R1BB3 = new VPBasicBlock();
-    VPBasicBlock *R1BB4 = new VPBasicBlock();
+    VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R1BB4 = Plan.createVPBasicBlock("");
     VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB4, "R1");
     R1BB2->setParent(R1);
     R1BB3->setParent(R1);
@@ -383,8 +383,8 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
     // Cycle.
     VPBlockUtils::connectBlocks(R1BB3, R1BB3);
 
-    VPBasicBlock *R2BB1 = new VPBasicBlock();
-    VPBasicBlock *R2BB2 = new VPBasicBlock();
+    VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("");
+    VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("");
     VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
     VPBlockUtils::connectBlocks(R2BB1, R2BB2);
     VPBlockUtils::connectBlocks(R1, R2);
@@ -467,14 +467,17 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
     //  VPBB2
     //
     VPlan &Plan = getPlan();
-    VPBasicBlock *R1BB1 = new VPBasicBlock("R1BB1");
-    VPBasicBlock *R1BB2 = new VPBasicBlock("R1BB2");
-    VPBasicBlock *R1BB3 = new VPBasicBlock("R1BB3");
+    VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("R1BB1");
+    VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("R1BB2");
+    VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("R1BB3");
     VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB3, "R1");
 
-    VPBasicBlock *R2BB1 = new VPBasicBlock("R2BB1");
-    VPBasicBlock *R2BB2 = new VPBasicBlock("R2BB2");
-    VPBasicBlock *R2BB3 = new VPBasicBlock("R2BB3");
+    VPBasicBlock *R2BB1 = Plan.createVPBasicBlock(""
+                                                  "R2BB1");
+    VPBasicBlock *R2BB2 = Plan.createVPBasicBlock(""
+                                                  "R2BB2");
+    VPBasicBlock *R2BB3 = Plan.createVPBasicBlock(""
+                                                  "R2BB3");
     VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB3, "R2");
     R2BB2->setParent(R2);
     VPBlockUtils::connectBlocks(R2BB1, R2BB2);
@@ -490,7 +493,8 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
 
     VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBlockUtils::connectBlocks(VPBB1, R1);
-    VPBasicBlock *VPBB2 = new VPBasicBlock("VPBB2");
+    VPBasicBlock *VPBB2 = Plan.createVPBasicBlock(""
+                                                  "VPBB2");
     VPBlockUtils::connectBlocks(R1, VPBB2);
 
     // Depth-first.
@@ -538,8 +542,8 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
     //   }
     //
     VPlan &Plan = getPlan();
-    VPBasicBlock *R2BB1 = new VPBasicBlock("R2BB1");
-    VPBasicBlock *R2BB2 = new VPBasicBlock("R2BB2");
+    VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("R2BB1");
+    VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("R2BB2");
     VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
     VPBlockUtils::connectBlocks(R2BB1, R2BB2);
 
@@ -592,10 +596,11 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
     //  VPBB2
     //
     VPlan &Plan = getPlan();
-    VPBasicBlock *R3BB1 = new VPBasicBlock("R3BB1");
+    VPBasicBlock *R3BB1 = Plan.createVPBasicBlock("R3BB1");
     VPRegionBlock *R3 = new VPRegionBlock(R3BB1, R3BB1, "R3");
 
-    VPBasicBlock *R2BB1 = new VPBasicBlock("R2BB1");
+    VPBasicBlock *R2BB1 = Plan.createVPBasicBlock(""
+                                                  "R2BB1");
     VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R3, "R2");
     R3->setParent(R2);
     VPBlockUtils::connectBlocks(R2BB1, R3);
@@ -604,7 +609,7 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
     R2->setParent(R1);
 
     VPBasicBlock *VPBB1 = Plan.getEntry();
-    VPBasicBlock *VPBB2 = new VPBasicBlock("VPBB2");
+    VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("VPBB2");
     VPBlockUtils::connectBlocks(VPBB1, R1);
     VPBlockUtils::connectBlocks(R1, VPBB2);
 
@@ -674,7 +679,7 @@ TEST_F(VPBasicBlockTest, print) {
   VPInstruction *I2 = new VPInstruction(Instruction::Sub, {I1});
   VPInstruction *I3 = new VPInstruction(Instruction::Br, {I1, I2});
 
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+  VPBasicBlock *VPBB1 = Plan.createVPBasicBlock("");
   VPBB1->appendRecipe(I1);
   VPBB1->appendRecipe(I2);
   VPBB1->appendRecipe(I3);
@@ -682,7 +687,7 @@ TEST_F(VPBasicBlockTest, print) {
 
   VPInstruction *I4 = new VPInstruction(Instruction::Mul, {I2, I1});
   VPInstruction *I5 = new VPInstruction(Instruction::Ret, {I4});
-  VPBasicBlock *VPBB2 = new VPBasicBlock();
+  VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
   VPBB2->appendRecipe(I4);
   VPBB2->appendRecipe(I5);
   VPBB2->setName("bb2");
@@ -783,7 +788,7 @@ TEST_F(VPBasicBlockTest, printPlanWithVFsAndUFs) {
   VPBB0->appendRecipe(TC);
 
   VPInstruction *I1 = new VPInstruction(Instruction::Add, {});
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+  VPBasicBlock *VPBB1 = Plan.createVPBasicBlock("");
   VPBB1->appendRecipe(I1);
   VPBB1->setName("bb1");
 
@@ -1238,7 +1243,7 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
 TEST_F(VPRecipeTest, dumpRecipeInPlan) {
   VPlan &Plan = getPlan();
   VPBasicBlock *VPBB0 = Plan.getEntry();
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+  VPBasicBlock *VPBB1 = Plan.createVPBasicBlock("");
   VPBlockUtils::connectBlocks(VPBB1, Plan.getScalarHeader());
   VPBlockUtils::connectBlocks(VPBB0, VPBB1);
 
@@ -1307,7 +1312,7 @@ TEST_F(VPRecipeTest, dumpRecipeInPlan) {
 TEST_F(VPRecipeTest, dumpRecipeUnnamedVPValuesInPlan) {
   VPlan &Plan = getPlan();
   VPBasicBlock *VPBB0 = Plan.getEntry();
-  VPBasicBlock *VPBB1 = new VPBasicBlock();
+  VPBasicBlock *VPBB1 = Plan.createVPBasicBlock("");
   VPBlockUtils::connectBlocks(VPBB1, Plan.getScalarHeader());
   VPBlockUtils::connectBlocks(VPBB0, VPBB1);
 
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanVerifierTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanVerifierTest.cpp
index 174249a7e85e32..5a29e7ac0893b6 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanVerifierTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanVerifierTest.cpp
@@ -27,7 +27,7 @@ TEST_F(VPVerifierTest, VPInstructionUseBeforeDefSameBB) {
   VPBB1->appendRecipe(UseI);
   VPBB1->appendRecipe(DefI);
 
-  VPBasicBlock *VPBB2 = new VPBasicBlock();
+  VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
   VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB2, "R1");
   VPBlockUtils::connectBlocks(VPBB1, R1);
   VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
@@ -51,7 +51,7 @@ TEST_F(VPVerifierTest, VPInstructionUseBeforeDefDifferentBB) {
       new VPInstruction(VPInstruction::BranchOnCond, {CanIV});
 
   VPBasicBlock *VPBB1 = Plan.getEntry();
-  VPBasicBlock *VPBB2 = new VPBasicBlock();
+  VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
 
   VPBB1->appendRecipe(UseI);
   VPBB2->appendRecipe(CanIV);
@@ -85,9 +85,9 @@ TEST_F(VPVerifierTest, VPBlendUseBeforeDefDifferentBB) {
 
   VPlan &Plan = getPlan();
   VPBasicBlock *VPBB1 = Plan.getEntry();
-  VPBasicBlock *VPBB2 = new VPBasicBlock();
-  VPBasicBlock *VPBB3 = new VPBasicBlock();
-  VPBasicBlock *VPBB4 = new VPBasicBlock();
+  VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
+  VPBasicBlock *VPBB3 = Plan.createVPBasicBlock("");
+  VPBasicBlock *VPBB4 = Plan.createVPBasicBlock("");
 
   VPBB1->appendRecipe(I1);
   VPBB2->appendRecipe(CanIV);
@@ -125,7 +125,7 @@ TEST_F(VPVerifierTest, DuplicateSuccessorsOutsideRegion) {
 
   VPlan &Plan = getPlan();
   VPBasicBlock *VPBB1 = Plan.getEntry();
-  VPBasicBlock *VPBB2 = new VPBasicBlock();
+  VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
 
   VPBB1->appendRecipe(I1);
   VPBB1->appendRecipe(BranchOnCond2);
@@ -158,8 +158,8 @@ TEST_F(VPVerifierTest, DuplicateSuccessorsInsideRegion) {
 
   VPlan &Plan = getPlan();
   VPBasicBlock *VPBB1 = Plan.getEntry();
-  VPBasicBlock *VPBB2 = new VPBasicBlock();
-  VPBasicBlock *VPBB3 = new VPBasicBlock();
+  VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
+  VPBasicBlock *VPBB3 = Plan.createVPBasicBlock("");
 
   VPBB1->appendRecipe(I1);
   VPBB2->appendRecipe(CanIV);
@@ -187,7 +187,7 @@ TEST_F(VPVerifierTest, DuplicateSuccessorsInsideRegion) {
 TEST_F(VPVerifierTest, BlockOutsideRegionWithParent) {
   VPlan &Plan = getPlan();
   VPBasicBlock *VPBB1 = Plan.getEntry();
-  VPBasicBlock *VPBB2 = new VPBasicBlock();
+  VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
 
   VPInstruction *DefI = new VPInstruction(Instruction::Add, {});
   VPInstruction *BranchOnCond =

>From fc34ca56164a75668a1dd5477e34acccb76e9708 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 29 Dec 2024 20:56:19 +0000
Subject: [PATCH 5/6] !fixup address comments, thanks!

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  1 +
 llvm/lib/Transforms/Vectorize/VPlan.cpp       | 62 +++++++++----------
 llvm/lib/Transforms/Vectorize/VPlan.h         | 13 ++--
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  8 +--
 .../Transforms/Vectorize/VPlanTest.cpp        |  3 +-
 5 files changed, 39 insertions(+), 48 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5a215a530b2767..f38db39db9cffd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8088,6 +8088,7 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck(
   VPBasicBlock *OldEntry = Plan.getEntry();
   VPBlockUtils::reassociateBlocks(OldEntry, NewEntry);
   Plan.setEntry(NewEntry);
+  // OldEntry is now dead and will be cleaned up when the plan gets destroyed.
 
   introduceCheckBlockInVPlan(Plan, Insert);
   return Insert;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 204a1e01b9313c..2cf2022c338250 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -470,10 +470,7 @@ void VPIRBasicBlock::execute(VPTransformState *State) {
 }
 
 VPIRBasicBlock *VPIRBasicBlock::clone() {
-  auto *NewBlock = getPlan()->createVPIRBasicBlock(IRBB);
-  for (VPRecipeBase &R : make_early_inc_range(*NewBlock))
-    R.eraseFromParent();
-
+  auto *NewBlock = getPlan()->createEmptyVPIRBasicBlock(IRBB);
   for (VPRecipeBase &R : Recipes)
     NewBlock->appendRecipe(R.clone());
   return NewBlock;
@@ -518,16 +515,6 @@ void VPBasicBlock::execute(VPTransformState *State) {
   executeRecipes(State, NewBB);
 }
 
-void VPBasicBlock::dropAllReferences(VPValue *NewValue) {
-  for (VPRecipeBase &R : Recipes) {
-    for (auto *Def : R.definedValues())
-      Def->replaceAllUsesWith(NewValue);
-
-    for (unsigned I = 0, E = R.getNumOperands(); I != E; I++)
-      R.setOperand(I, NewValue);
-  }
-}
-
 VPBasicBlock *VPBasicBlock::clone() {
   auto *NewBlock = getPlan()->createVPBasicBlock(getName());
   for (VPRecipeBase &R : *this)
@@ -720,13 +707,6 @@ VPRegionBlock *VPRegionBlock::clone() {
   return NewRegion;
 }
 
-void VPRegionBlock::dropAllReferences(VPValue *NewValue) {
-  for (VPBlockBase *Block : vp_depth_first_shallow(Entry))
-    // Drop all references in VPBasicBlocks and replace all uses with
-    // DummyValue.
-    Block->dropAllReferences(NewValue);
-}
-
 void VPRegionBlock::execute(VPTransformState *State) {
   ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>>
       RPOT(Entry);
@@ -839,16 +819,22 @@ VPlan::VPlan(Loop *L) {
 }
 
 VPlan::~VPlan() {
-  if (Entry) {
     VPValue DummyValue;
 
-    for (auto *VPB : reverse(CreatedBlocks))
-      VPB->dropAllReferences(&DummyValue);
-
-    for (auto *VPB : reverse(CreatedBlocks)) {
+    for (auto *VPB : CreatedBlocks) {
+      if (auto *VPBB = dyn_cast<VPBasicBlock>(VPB)) {
+        // Replace all operands of recipes and all VPValues define in VPBB with
+        // DummyValue so the block can be deleted.
+        for (VPRecipeBase &R : *VPBB) {
+          for (auto *Def : R.definedValues())
+            Def->replaceAllUsesWith(&DummyValue);
+
+          for (unsigned I = 0, E = R.getNumOperands(); I != E; I++)
+            R.setOperand(I, &DummyValue);
+        }
+      }
       delete VPB;
     }
-  }
   for (VPValue *VPV : VPLiveInsToFree)
     delete VPV;
   if (BackedgeTakenCount)
@@ -1222,7 +1208,7 @@ static void remapOperands(VPBlockBase *Entry, VPBlockBase *NewEntry,
 }
 
 VPlan *VPlan::duplicate() {
-  unsigned CreatedBlockSize = CreatedBlocks.size();
+  unsigned NumBlocksBeforeCloning = CreatedBlocks.size();
   // Clone blocks.
   const auto &[NewEntry, __] = cloneFrom(Entry);
 
@@ -1264,20 +1250,28 @@ VPlan *VPlan::duplicate() {
          "TripCount must have been added to Old2NewVPValues");
   NewPlan->TripCount = Old2NewVPValues[TripCount];
 
-  // Transfer cloned blocks to new VPlan.
-  for (unsigned I : seq<unsigned>(CreatedBlockSize, CreatedBlocks.size()))
-    NewPlan->CreatedBlocks.push_back(CreatedBlocks[I]);
-  CreatedBlocks.truncate(CreatedBlockSize);
+  // Transfer all cloned blocks (those appended to CreatedBlocks while cloning)
+  // from the current to the new VPlan.
+  unsigned NumBlocksAfterCloning = CreatedBlocks.size();
+  for (unsigned I :
+       seq<unsigned>(NumBlocksBeforeCloning, NumBlocksAfterCloning))
+    NewPlan->CreatedBlocks.push_back(this->CreatedBlocks[I]);
+  CreatedBlocks.truncate(NumBlocksBeforeCloning);
 
   return NewPlan;
 }
 
-VPIRBasicBlock *VPlan::createVPIRBasicBlock(BasicBlock *IRBB) {
+VPIRBasicBlock *VPlan::createEmptyVPIRBasicBlock(BasicBlock *IRBB) {
   auto *VPIRBB = new VPIRBasicBlock(IRBB);
+  CreatedBlocks.push_back(VPIRBB);
+  return VPIRBB;
+}
+
+VPIRBasicBlock *VPlan::createVPIRBasicBlock(BasicBlock *IRBB) {
+  auto *VPIRBB = createEmptyVPIRBasicBlock(IRBB);
   for (Instruction &I :
        make_range(IRBB->begin(), IRBB->getTerminator()->getIterator()))
     VPIRBB->appendRecipe(new VPIRInstruction(I));
-  CreatedBlocks.push_back(VPIRBB);
   return VPIRBB;
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c094f57e2fabca..aff503e2afd183 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -643,10 +643,6 @@ class VPBlockBase {
     return true;
   }
 
-  /// Replace all operands of VPUsers in the block with \p NewValue and also
-  /// replaces all uses of VPValues defined in the block with NewValue.
-  virtual void dropAllReferences(VPValue *NewValue) = 0;
-
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   void printAsOperand(raw_ostream &OS, bool PrintType = false) const {
     OS << getName();
@@ -3553,8 +3549,6 @@ class VPBasicBlock : public VPBlockBase {
     return make_range(begin(), getFirstNonPhi());
   }
 
-  void dropAllReferences(VPValue *NewValue) override;
-
   /// Split current block at \p SplitAt by inserting a new block between the
   /// current block and its successors and moving all recipes starting at
   /// SplitAt to the new block. Returns the new block.
@@ -3711,8 +3705,6 @@ class VPRegionBlock : public VPBlockBase {
   // Return the cost of this region.
   InstructionCost cost(ElementCount VF, VPCostContext &Ctx) override;
 
-  void dropAllReferences(VPValue *NewValue) override;
-
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print this VPRegionBlock to \p O (recursively), prefixing all lines with
   /// \p Indent. \p SlotTracker is used to print unnamed VPValue's using
@@ -4043,6 +4035,11 @@ class VPlan {
     return VPB;
   }
 
+  /// Create a VPIRBasicBlock wrapping \p IRBB, but do not create
+  /// VPIRInstructions wrapping the instructions in \p IRBB. The returned
+  /// block is owned by the VPlan and deleted once the VPlan is destroyed.
+  VPIRBasicBlock *createEmptyVPIRBasicBlock(BasicBlock *IRBB);
+
   /// Create a VPIRBasicBlock from \p IRBB containing VPIRInstructions for all
   /// instructions in \p IRBB, except its terminator which is managed in VPlan.
   /// The returned block is owned by the VPlan and deleted once the VPlan is
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index cd3ea561e6aac0..c5695b00fe58dd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -217,7 +217,7 @@ static VPBasicBlock *getPredicatedThenBlock(VPRegionBlock *R) {
 // is connected to a successor replicate region with the same predicate by a
 // single, empty VPBasicBlock.
 static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) {
-  SetVector<VPRegionBlock *> DeletedRegions;
+  SmallPtrSet<VPRegionBlock *, 4> TransformedRegions;
 
   // Collect replicate regions followed by an empty block, followed by another
   // replicate region with matching masks to process front. This is to avoid
@@ -248,7 +248,7 @@ static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) {
 
   // Move recipes from Region1 to its successor region, if both are triangles.
   for (VPRegionBlock *Region1 : WorkList) {
-    if (DeletedRegions.contains(Region1))
+    if (TransformedRegions.contains(Region1))
       continue;
     auto *MiddleBasicBlock = cast<VPBasicBlock>(Region1->getSingleSuccessor());
     auto *Region2 = cast<VPRegionBlock>(MiddleBasicBlock->getSingleSuccessor());
@@ -294,10 +294,10 @@ static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) {
       VPBlockUtils::connectBlocks(Pred, MiddleBasicBlock);
     }
     VPBlockUtils::disconnectBlocks(Region1, MiddleBasicBlock);
-    DeletedRegions.insert(Region1);
+    TransformedRegions.insert(Region1);
   }
 
-  return !DeletedRegions.empty();
+  return !TransformedRegions.empty();
 }
 
 static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe,
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
index 808dc89a3e2f10..bde37670f0fce7 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -493,8 +493,7 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
 
     VPBasicBlock *VPBB1 = Plan.getEntry();
     VPBlockUtils::connectBlocks(VPBB1, R1);
-    VPBasicBlock *VPBB2 = Plan.createVPBasicBlock(""
-                                                  "VPBB2");
+    VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("VPBB2");
     VPBlockUtils::connectBlocks(R1, VPBB2);
 
     // Depth-first.

>From 8f83ad84ebd28c53203b97be3de03c23122ebc67 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 30 Dec 2024 10:46:20 +0000
Subject: [PATCH 6/6] !fixup address latest comments, thanks!

---
 llvm/lib/Transforms/Vectorize/VPlan.cpp           | 3 ++-
 llvm/lib/Transforms/Vectorize/VPlan.h             | 6 +++---
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 1 +
 llvm/unittests/Transforms/Vectorize/VPlanTest.cpp | 9 +++------
 4 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 2cf2022c338250..1914774d88019f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -823,7 +823,7 @@ VPlan::~VPlan() {
 
     for (auto *VPB : CreatedBlocks) {
       if (auto *VPBB = dyn_cast<VPBasicBlock>(VPB)) {
-        // Replace all operands of recipes and all VPValues define in VPBB with
+        // Replace all operands of recipes and all VPValues defined in VPBB with
         // DummyValue so the block can be deleted.
         for (VPRecipeBase &R : *VPBB) {
           for (auto *Def : R.definedValues())
@@ -960,6 +960,7 @@ static void replaceVPBBWithIRVPBB(VPBasicBlock *VPBB, BasicBlock *IRBB) {
   }
 
   VPBlockUtils::reassociateBlocks(VPBB, IRVPBB);
+  // VPBB is now dead and will be cleaned up when the plan gets destroyed.
 }
 
 /// Generate the code inside the preheader and body of the vectorized loop.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index aff503e2afd183..199e0dd7a6becb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -4041,9 +4041,9 @@ class VPlan {
   VPIRBasicBlock *createEmptyVPIRBasicBlock(BasicBlock *IRBB);
 
   /// Create a VPIRBasicBlock from \p IRBB containing VPIRInstructions for all
-  /// instructions in \p IRBB, except its terminator which is managed in VPlan.
-  /// The returned block is owned by the VPlan and deleted once the VPlan is
-  /// destroyed.
+  /// instructions in \p IRBB, except its terminator which is managed by the
+  /// successors of the block in VPlan. The returned block is owned by the VPlan
+  /// and deleted once the VPlan is destroyed.
   VPIRBasicBlock *createVPIRBasicBlock(BasicBlock *IRBB);
 };
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index c5695b00fe58dd..1f5acf996a7720 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -398,6 +398,7 @@ static bool mergeBlocksIntoPredecessors(VPlan &Plan) {
       VPBlockUtils::disconnectBlocks(VPBB, Succ);
       VPBlockUtils::connectBlocks(PredVPBB, Succ);
     }
+    // VPBB is now dead and will be cleaned up when the plan gets destroyed.
   }
   return !WorkList.empty();
 }
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
index bde37670f0fce7..00c6744a44e748 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -472,12 +472,9 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
     VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("R1BB3");
     VPRegionBlock *R1 = Plan.createVPRegionBlock(R1BB1, R1BB3, "R1");
 
-    VPBasicBlock *R2BB1 = Plan.createVPBasicBlock(""
-                                                  "R2BB1");
-    VPBasicBlock *R2BB2 = Plan.createVPBasicBlock(""
-                                                  "R2BB2");
-    VPBasicBlock *R2BB3 = Plan.createVPBasicBlock(""
-                                                  "R2BB3");
+    VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("R2BB1");
+    VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("R2BB2");
+    VPBasicBlock *R2BB3 = Plan.createVPBasicBlock("R2BB3");
     VPRegionBlock *R2 = Plan.createVPRegionBlock(R2BB1, R2BB3, "R2");
     R2BB2->setParent(R2);
     VPBlockUtils::connectBlocks(R2BB1, R2BB2);



More information about the llvm-commits mailing list