[llvm] [VPlan] Add VPIRWrapperBlock, use to model pre-preheader. (PR #93398)

via llvm-commits llvm-commits at lists.llvm.org
Sat May 25 21:50:41 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

<details>
<summary>Changes</summary>

This patch adds a new special type of VPBasicBlock that wraps an existing IR basic block. Recipes of the block get added before the terminator of the wrapped IR basic block. Making it a subclass of VPBasicBlock avoids duplicating the various APIs for managing recipes in a block, and ensures that traversals filtering for VPBasicBlocks automatically apply to it as well.

Initially, VPIRWrapperBlocks are only used for the pre-preheader (wrapping the original preheader of the scalar loop).

As a follow-up, this will be used to move more parts of the skeleton inside VPlan, starting with the branch and condition in the middle block.

Note: This requires updating all VPlan-printing tests, which I will do once we converge on a final version.

Separated out of https://github.com/llvm/llvm-project/pull/92651

---
Full diff: https://github.com/llvm/llvm-project/pull/93398.diff


5 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+2-2) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.cpp (+55-2) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+60-5) 
- (modified) llvm/test/Transforms/LoopVectorize/vplan-printing-before-execute.ll (+2-2) 
- (modified) llvm/unittests/Transforms/Vectorize/VPlanTestBase.h (+8-6) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 48981a6bd39e3..e71a0df1d9c7c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8607,7 +8607,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
   // loop region contains a header and latch basic blocks.
   VPlanPtr Plan = VPlan::createInitialVPlan(
       createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop),
-      *PSE.getSE());
+      *PSE.getSE(), OrigLoop->getLoopPreheader());
   VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body");
   VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch");
   VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB);
@@ -8855,7 +8855,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
   // Create new empty VPlan
   auto Plan = VPlan::createInitialVPlan(
       createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop),
-      *PSE.getSE());
+      *PSE.getSE(), OrigLoop->getLoopPreheader());
 
   // Build hierarchical CFG
   VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI, *Plan);
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index d71d7580e6ba6..8998e392a433e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -442,6 +442,58 @@ VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) {
   return NewBB;
 }
 
+void VPIRWrapperBlock::execute(VPTransformState *State) {
+  // Hook the wrapped IR block up as successor of each predecessor's IR block;
+  // predecessors must already have been generated.
+  for (VPBlockBase *PredVPBlock : getHierarchicalPredecessors()) {
+    VPBasicBlock *PredVPBB = PredVPBlock->getExitingBasicBlock();
+    auto &PredVPSuccessors = PredVPBB->getHierarchicalSuccessors();
+    BasicBlock *PredBB = State->CFG.VPBB2IRBB[PredVPBB];
+    assert(PredBB && "Predecessor basic-block not found building successor.");
+    LLVM_DEBUG(dbgs() << "LV: draw edge from " << PredBB->getName() << '\n');
+
+    if (auto *TermBr = dyn_cast<BranchInst>(PredBB->getTerminator())) {
+      // Set each forward successor here when it is created, excluding
+      // backedges. A backward successor is set when the branch is created.
+      unsigned Idx = PredVPSuccessors.front() == this ? 0 : 1;
+      assert(!TermBr->getSuccessor(Idx) &&
+             "Trying to reset an existing successor block.");
+      TermBr->setSuccessor(Idx, WrappedBlock);
+    }
+  }
+
+  assert(getHierarchicalSuccessors().empty() &&
+         "VPIRWrapperBlock cannot have successors");
+  State->CFG.VPBB2IRBB[this] = WrappedBlock;
+  State->CFG.PrevVPBB = this;
+
+  // Emit recipes before the wrapped block's existing terminator (of any kind);
+  // the terminator itself is left untouched.
+  State->Builder.SetInsertPoint(WrappedBlock->getTerminator());
+
+  for (VPRecipeBase &Recipe : *this)
+    Recipe.execute(*State);
+
+  LLVM_DEBUG(dbgs() << "LV: filled BB:" << *WrappedBlock);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPIRWrapperBlock::print(raw_ostream &O, const Twine &Indent,
+                             VPSlotTracker &SlotTracker) const {
+  // Header line, e.g. "ir-bb<entry>:", then one indented line per recipe.
+  O << Indent << "ir-bb<" << getName() << ">:\n";
+  const auto NestedIndent = Indent + "  ";
+  for (const VPRecipeBase &R : *this) {
+    R.print(O, NestedIndent, SlotTracker);
+    O << '\n';
+  }
+  // Wrapper blocks are leaves of the CFG; the trailer is "No successors".
+  assert(getSuccessors().empty() &&
+         "Wrapper blocks should not have successors");
+  printSuccessors(O, Indent);
+}
+#endif
+
 void VPBasicBlock::execute(VPTransformState *State) {
   bool Replica = State->Instance && !State->Instance->isFirstIteration();
   VPBasicBlock *PrevVPBB = State->CFG.PrevVPBB;
@@ -769,8 +821,9 @@ VPlan::~VPlan() {
     delete BackedgeTakenCount;
 }
 
-VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE) {
-  VPBasicBlock *Preheader = new VPBasicBlock("ph");
+VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE,
+                                   BasicBlock *PH) {
+  VPIRWrapperBlock *Preheader = new VPIRWrapperBlock(PH);
   VPBasicBlock *VecPreheader = new VPBasicBlock("vector.ph");
   auto Plan = std::make_unique<VPlan>(Preheader, VecPreheader);
   Plan->TripCount =
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 3aee17921086d..20a12b571d0c0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -473,7 +473,11 @@ class VPBlockBase {
   /// that are actually instantiated. Values of this enumeration are kept in the
   /// SubclassID field of the VPBlockBase objects. They are used for concrete
   /// type identification.
-  using VPBlockTy = enum { VPBasicBlockSC, VPRegionBlockSC };
+  using VPBlockTy = enum {
+    VPBasicBlockSC,
+    VPRegionBlockSC,
+    VPIRWrapperBlockSC
+  };
 
   using VPBlocksTy = SmallVectorImpl<VPBlockBase *>;
 
@@ -2834,6 +2838,10 @@ class VPBasicBlock : public VPBlockBase {
   /// The VPRecipes held in the order of output instructions to generate.
   RecipeListTy Recipes;
 
+protected:
+  VPBasicBlock(const unsigned char BlockSC, const Twine &Name = "")
+      : VPBlockBase(BlockSC, Name.str()) {}
+
 public:
   VPBasicBlock(const Twine &Name = "", VPRecipeBase *Recipe = nullptr)
       : VPBlockBase(VPBasicBlockSC, Name.str()) {
@@ -2882,7 +2890,8 @@ class VPBasicBlock : public VPBlockBase {
 
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPBlockBase *V) {
-    return V->getVPBlockID() == VPBlockBase::VPBasicBlockSC;
+    return V->getVPBlockID() == VPBlockBase::VPBasicBlockSC ||
+           V->getVPBlockID() == VPBlockBase::VPIRWrapperBlockSC;
   }
 
   void insert(VPRecipeBase *Recipe, iterator InsertPt) {
@@ -2951,6 +2960,50 @@ class VPBasicBlock : public VPBlockBase {
   BasicBlock *createEmptyBasicBlock(VPTransformState::CFGState &CFG);
 };
 
+/// A special type of VPBasicBlock that wraps an existing IR basic block.
+/// Recipes of the block get added before the terminator of the wrapped IR basic
+/// block.
+class VPIRWrapperBlock : public VPBasicBlock {
+  /// The IR basic block this VPBasicBlock wraps (not owned).
+  BasicBlock *WrappedBlock;
+
+public:
+  explicit VPIRWrapperBlock(BasicBlock *WrappedBlock)
+      : VPBasicBlock(VPIRWrapperBlockSC, WrappedBlock->getName()),
+        WrappedBlock(WrappedBlock) {}
+
+  ~VPIRWrapperBlock() override = default;
+
+  static inline bool classof(const VPBlockBase *V) {
+    return V->getVPBlockID() == VPBlockBase::VPIRWrapperBlockSC;
+  }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print this VPIRWrapperBlock to \p O, prefixing all lines with \p Indent.
+  /// \p SlotTracker is used to print unnamed VPValue's using consecutive
+  /// numbers, applied to the whole VPlan for consistent output.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+  using VPBlockBase::print; // Get the print(raw_ostream &O) version.
+#endif
+  /// Execute the recipes before the wrapped IR block's terminator.
+  void execute(VPTransformState *State) override;
+
+  VPIRWrapperBlock *clone() override {
+    auto *NewBlock = new VPIRWrapperBlock(WrappedBlock);
+    for (VPRecipeBase &R : *this)
+      NewBlock->appendRecipe(R.clone());
+    return NewBlock;
+  }
+
+  // dropAllReferences() is intentionally inherited from VPBasicBlock so that
+  // recipes in this block release their operands/users before deletion.
+  /// Re-target this block at \p N; the block's name is left unchanged.
+  void resetBlock(BasicBlock *N) { WrappedBlock = N; }
+
+  BasicBlock *getWrappedBlock() const { return WrappedBlock; }
+};
+
 /// VPRegionBlock represents a collection of VPBasicBlocks and VPRegionBlocks
 /// which form a Single-Entry-Single-Exiting subgraph of the output IR CFG.
 /// A VPRegionBlock may indicate that its contents are to be replicated several
@@ -3139,12 +3192,12 @@ class VPlan {
   ~VPlan();
 
   /// Create initial VPlan skeleton, having an "entry" VPBasicBlock (wrapping
-  /// original scalar pre-header) which contains SCEV expansions that need to
-  /// happen before the CFG is modified; a VPBasicBlock for the vector
+  /// original scalar pre-header \p PH) which contains SCEV expansions that need
+  /// to happen before the CFG is modified; a VPBasicBlock for the vector
   /// pre-header, followed by a region for the vector loop, followed by the
   /// middle VPBasicBlock.
   static VPlanPtr createInitialVPlan(const SCEV *TripCount,
-                                     ScalarEvolution &PSE);
+                                     ScalarEvolution &PSE, BasicBlock *PH);
 
   /// Prepare the plan for execution, setting up the required live-in values.
   void prepareToExecute(Value *TripCount, Value *VectorTripCount,
@@ -3321,6 +3374,8 @@ class VPlanPrinter {
   /// its successor blocks.
   void dumpBasicBlock(const VPBasicBlock *BasicBlock);
 
+  void dumpIRWrapperBlock(const VPIRWrapperBlock *WrapperBlock);
+
   /// Print a given \p Region of the Plan.
   void dumpRegion(const VPRegionBlock *Region);
 
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-before-execute.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-before-execute.ll
index ca9dfdc6f6d29..2bb3c898c7cda 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-before-execute.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-before-execute.ll
@@ -13,7 +13,7 @@ define void @test_tc_less_than_16(ptr %A, i64 %N) {
 ; CHECK-NEXT: Live-in vp<[[VTC:%.+]]> = vector-trip-count
 ; CHECK-NEXT: vp<[[TC:%.+]]> = original trip-count
 ; CHECK-EMPTY:
-; CHECK-NEXT: ph:
+; CHECK-NEXT: ir-bb<entry>:
 ; CHECK-NEXT:   EMIT vp<[[TC]]> = EXPAND SCEV (zext i4 (trunc i64 %N to i4) to i64)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
@@ -45,7 +45,7 @@ define void @test_tc_less_than_16(ptr %A, i64 %N) {
 ; CHECK-NEXT: Live-in vp<[[VFxUF:%.+]]> = VF * UF
 ; CHECK-NEXT: vp<[[TC:%.+]]> = original trip-count
 ; CHECK-EMPTY:
-; CHECK-NEXT: ph:
+; CHECK-NEXT: ir-bb<entry>:
 ; CHECK-NEXT:   EMIT vp<[[TC]]> = EXPAND SCEV (zext i4 (trunc i64 %N to i4) to i64)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
index 6cd43f6803130..c658724278fe0 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
@@ -67,9 +67,10 @@ class VPlanTestBase : public testing::Test {
     assert(!verifyFunction(F) && "input function must be valid");
     doAnalysis(F);
 
-    auto Plan = VPlan::createInitialVPlan(
-        SE->getBackedgeTakenCount(LI->getLoopFor(LoopHeader)), *SE);
-    VPlanHCFGBuilder HCFGBuilder(LI->getLoopFor(LoopHeader), LI.get(), *Plan);
+    Loop *L = LI->getLoopFor(LoopHeader);
+    auto Plan = VPlan::createInitialVPlan(SE->getBackedgeTakenCount(L), *SE,
+                                          L->getLoopPreheader());
+    VPlanHCFGBuilder HCFGBuilder(L, LI.get(), *Plan);
     HCFGBuilder.buildHierarchicalCFG();
     return Plan;
   }
@@ -80,9 +81,10 @@ class VPlanTestBase : public testing::Test {
     assert(!verifyFunction(F) && "input function must be valid");
     doAnalysis(F);
 
-    auto Plan = VPlan::createInitialVPlan(
-        SE->getBackedgeTakenCount(LI->getLoopFor(LoopHeader)), *SE);
-    VPlanHCFGBuilder HCFGBuilder(LI->getLoopFor(LoopHeader), LI.get(), *Plan);
+    Loop *L = LI->getLoopFor(LoopHeader);
+    auto Plan = VPlan::createInitialVPlan(SE->getBackedgeTakenCount(L), *SE,
+                                          L->getLoopPreheader());
+    VPlanHCFGBuilder HCFGBuilder(L, LI.get(), *Plan);
     HCFGBuilder.buildPlainCFG();
     return Plan;
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/93398


More information about the llvm-commits mailing list