[llvm] [VPlan][NFC] Add new getMiddleBlock interface to VPlan (PR #113558)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 1 02:57:18 PDT 2024


https://github.com/david-arm updated https://github.com/llvm/llvm-project/pull/113558

>From 2b54b69aba946f3ed108d091930d9bd0f2e8614d Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Fri, 1 Nov 2024 09:56:17 +0000
Subject: [PATCH] [VPlan][NFC] Add new getMiddleBlock interface to VPlan

This work is in preparation for PRs #112138 and #88385 where
the middle block is not guaranteed to be the immediate successor
of the vector loop region. I've simply added new getMiddleBlock()
interfaces to VPlan that for now just return

cast<VPBasicBlock>(VectorRegion->getSingleSuccessor())

Once PR #112138 lands, we'll need to do more work to discover
the middle block.
---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 17 ++++++-----------
 llvm/lib/Transforms/Vectorize/VPlan.cpp         |  3 +--
 llvm/lib/Transforms/Vectorize/VPlan.h           | 10 ++++++++++
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 58fcba93f1a188..659b4c30a58ada 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7703,8 +7703,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
   BestVPlan.execute(&State);
 
   // 2.5 Collect reduction resume values.
-  auto *ExitVPBB =
-      cast<VPBasicBlock>(BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
+  auto *ExitVPBB = BestVPlan.getMiddleBlock();
   if (VectorizingEpilogue)
     for (VPRecipeBase &R : *ExitVPBB) {
       fixReductionScalarResumeWhenVectorizingEpilog(
@@ -8830,8 +8829,7 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) {
 static SetVector<VPIRInstruction *> collectUsersInExitBlock(
     Loop *OrigLoop, VPRecipeBuilder &Builder, VPlan &Plan,
     const MapVector<PHINode *, InductionDescriptor> &Inductions) {
-  auto *MiddleVPBB =
-      cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor());
+  auto *MiddleVPBB = Plan.getMiddleBlock();
   // No edge from the middle block to the unique exit block has been inserted
   // and there is nothing to fix from vector loop; phis should have incoming
   // from scalar loop only.
@@ -8876,8 +8874,7 @@ addUsersInExitBlock(VPlan &Plan,
   if (ExitUsersToFix.empty())
     return;
 
-  auto *MiddleVPBB =
-      cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor());
+  auto *MiddleVPBB = Plan.getMiddleBlock();
   VPBuilder B(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
 
   // Introduce extract for exiting values and update the VPIRInstructions
@@ -8905,7 +8902,7 @@ static void addExitUsersForFirstOrderRecurrences(
     VPlan &Plan, SetVector<VPIRInstruction *> &ExitUsersToFix) {
   VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
   auto *ScalarPHVPBB = Plan.getScalarPreheader();
-  auto *MiddleVPBB = cast<VPBasicBlock>(VectorRegion->getSingleSuccessor());
+  auto *MiddleVPBB = Plan.getMiddleBlock();
   VPBuilder ScalarPHBuilder(ScalarPHVPBB);
   VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
   VPValue *TwoVPV = Plan.getOrAddLiveIn(
@@ -9085,8 +9082,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
         return Legal->blockNeedsPredication(BB) || NeedsBlends;
       });
-  auto *MiddleVPBB =
-      cast<VPBasicBlock>(Plan->getVectorLoopRegion()->getSingleSuccessor());
+  auto *MiddleVPBB = Plan->getMiddleBlock();
   VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
   for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
     // Relevant instructions from basic block BB will be grouped into VPRecipe
@@ -9303,8 +9299,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
   using namespace VPlanPatternMatch;
   VPRegionBlock *VectorLoopRegion = Plan->getVectorLoopRegion();
   VPBasicBlock *Header = VectorLoopRegion->getEntryBasicBlock();
-  VPBasicBlock *MiddleVPBB =
-      cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
+  VPBasicBlock *MiddleVPBB = Plan->getMiddleBlock();
   for (VPRecipeBase &R : Header->phis()) {
     auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
     if (!PhiR || !PhiR->isInLoop() || (MinVF.isScalar() && !PhiR->isOrdered()))
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 84880bbb19793d..7c06fb2353822b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1034,8 +1034,7 @@ void VPlan::execute(VPTransformState *State) {
   // skeleton creation, so we can only create the VPIRBasicBlocks now during
   // VPlan execution rather than earlier during VPlan construction.
   BasicBlock *MiddleBB = State->CFG.ExitBB;
-  VPBasicBlock *MiddleVPBB =
-      cast<VPBasicBlock>(getVectorLoopRegion()->getSingleSuccessor());
+  VPBasicBlock *MiddleVPBB = getMiddleBlock();
   BasicBlock *ScalarPh = MiddleBB->getSingleSuccessor();
   replaceVPBBWithIRVPBB(getScalarPreheader(), ScalarPh);
   replaceVPBBWithIRVPBB(MiddleVPBB, MiddleBB);
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 4e5878cae2ddc3..cf4b38b340dc12 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3748,6 +3748,16 @@ class VPlan {
     return cast<VPBasicBlock>(ScalarHeader->getSinglePredecessor());
   }
 
+  /// Returns the 'middle' block of the plan, i.e. the block that follows the
+  /// vector loop region and selects whether to execute the scalar tail loop or
+  /// branch to the exit block. Currently it is the region's single successor.
+  const VPBasicBlock *getMiddleBlock() const {
+    return cast<VPBasicBlock>(getVectorLoopRegion()->getSingleSuccessor());
+  }
+  VPBasicBlock *getMiddleBlock() {
+    return cast<VPBasicBlock>(getVectorLoopRegion()->getSingleSuccessor());
+  }
+
   /// The trip count of the original loop.
   VPValue *getTripCount() const {
     assert(TripCount && "trip count needs to be set before accessing it");



More information about the llvm-commits mailing list