[llvm] [LV][VPlan] Add initial support for CSA vectorization (PR #106560)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 29 07:19:49 PDT 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 18c79ca3607bfe9cc6fd083186f3b462f5abff7e c15cf30575340782bf6d2596ad84f118230e0176 --extensions cpp,h -- llvm/include/llvm/Analysis/CSADescriptors.h llvm/lib/Analysis/CSADescriptors.cpp llvm/include/llvm/Analysis/TargetTransformInfo.h llvm/include/llvm/Analysis/TargetTransformInfoImpl.h llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h llvm/lib/Analysis/TargetTransformInfo.cpp llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp llvm/lib/Transforms/Vectorize/LoopVectorize.cpp llvm/lib/Transforms/Vectorize/VPlan.cpp llvm/lib/Transforms/Vectorize/VPlan.h llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp llvm/lib/Transforms/Vectorize/VPlanTransforms.h llvm/lib/Transforms/Vectorize/VPlanValue.h llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index 7ef29a8cb3..a492af0b38 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -315,7 +315,7 @@ public:
   bool isInductionPhi(const Value *V) const;
 
   /// Returns the CSAs found in the loop.
-  const CSAList& getCSAs() const { return CSAs; }
+  const CSAList &getCSAs() const { return CSAs; }
 
   /// Returns true if Phi is the root of a CSA in the loop.
   bool isCSAPhi(PHINode *Phi) const { return CSAs.count(Phi) != 0; }
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 9633ba9cc7..fac10e55a9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -1578,7 +1578,7 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
     ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
 
   SmallPtrSet<const Value *, 8> CSALiveOuts;
-  for (const auto &CSA: getCSAs())
+  for (const auto &CSA : getCSAs())
     CSALiveOuts.insert(CSA.second.getAssignment());
 
   // TODO: handle non-reduction outside users when tail is folded by masking.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5e45f50048..92681cbfe9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2939,7 +2939,7 @@ LoopVectorizationCostModel::getVectorIntrinsicCost(CallInst *CI,
 }
 
 void InnerLoopVectorizer::fixCSALiveOuts(VPTransformState &State, VPlan &Plan) {
-  for (const auto &CSA: Plan.getCSAStates()) {
+  for (const auto &CSA : Plan.getCSAStates()) {
     VPCSADataUpdateRecipe *VPDataUpdate = CSA.second->getDataUpdate();
     assert(VPDataUpdate &&
            "VPDataUpdate must have been introduced prior to fixing live outs");
@@ -7276,11 +7276,9 @@ InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan,
 /// not have corresponding recipes in \p Plan and are not marked to be ignored
 /// in \p CostCtx. This means the VPlan contains simplification that the legacy
 /// cost-model did not account for.
-static bool
-planContainsAdditionalSimplifications(VPlan &Plan, ElementCount VF,
-                                      VPCostContext &CostCtx, Loop *TheLoop,
-                                      LoopVectorizationCostModel &CM,
-                                      LoopVectorizationLegality &Legal) {
+static bool planContainsAdditionalSimplifications(
+    VPlan &Plan, ElementCount VF, VPCostContext &CostCtx, Loop *TheLoop,
+    LoopVectorizationCostModel &CM, LoopVectorizationLegality &Legal) {
 
   // CSA cost is more complicated since there is significant overhead in the
   // preheader and middle block. It also contains recipes that are not backed by
@@ -8525,7 +8523,7 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
       PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
                                            CM.isInLoopReduction(Phi),
                                            CM.useOrderedReductions(RdxDesc));
-    } else if (Legal->isFixedOrderRecurrence(Phi)){
+    } else if (Legal->isFixedOrderRecurrence(Phi)) {
       // TODO: Currently fixed-order recurrences are modeled as chains of
       // first-order recurrences. If there are no users of the intermediate
       // recurrences in the chain, the fixed order recurrence should be modeled
@@ -8620,8 +8618,8 @@ addCSAPreprocessRecipes(const LoopVectorizationLegality::CSAList &CSAs,
       continue;
     }
 
-    auto *VPInitMask = new VPInstruction(VPInstruction::CSAInitMask, {}, DL,
-                                         "csa.init.mask");
+    auto *VPInitMask =
+        new VPInstruction(VPInstruction::CSAInitMask, {}, DL, "csa.init.mask");
     auto *VPInitData = new VPInstruction(VPInstruction::CSAInitData,
                                          {VPInitScalar}, DL, "csa.init.data");
     PreheaderVPBB->appendRecipe(VPInitMask);
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 7d09ce6e96..3859e8e1a7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3536,9 +3536,7 @@ public:
                                      bool RequiresScalarEpilogueCheck,
                                      bool TailFolded, Loop *TheLoop);
 
-  void addCSAState(PHINode *Phi, VPCSAState * S) {
-    CSAStates.insert({Phi , S});
-  }
+  void addCSAState(PHINode *Phi, VPCSAState *S) { CSAStates.insert({Phi, S}); }
 
   MapVector<PHINode *, VPCSAState *> const &getCSAStates() const {
     return CSAStates;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index de3fd7a912..787f410214 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -708,18 +708,18 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
     return V;
   }
   case VPInstruction::CSAMaskSel: {
-      Value *WidenedCond = State.get(getOperand(0), Part);
-      Value *MaskPhi = State.get(getOperand(1), Part);
-      Value *AnyActive = State.get(getOperand(2), Part, /*NeedsScalar=*/true);
-      // If not the first Part, use the mask from the previous unrolled Part
-      Value *OldMask = Part == 0 ? MaskPhi : State.get(this, Part - 1);
-      Value *MaskSel = State.Builder.CreateSelect(AnyActive, WidenedCond,
-                                                  OldMask, "csa.mask.sel");
-      // MaskPhi wants to use the most recently updated mask. That's the one
-      // that corresponds to the last Part.
-      if (Part == State.UF - 1)
-        cast<PHINode>(MaskPhi)->addIncoming(MaskSel, State.CFG.PrevBB);
-      return MaskSel;
+    Value *WidenedCond = State.get(getOperand(0), Part);
+    Value *MaskPhi = State.get(getOperand(1), Part);
+    Value *AnyActive = State.get(getOperand(2), Part, /*NeedsScalar=*/true);
+    // If not the first Part, use the mask from the previous unrolled Part
+    Value *OldMask = Part == 0 ? MaskPhi : State.get(this, Part - 1);
+    Value *MaskSel = State.Builder.CreateSelect(AnyActive, WidenedCond, OldMask,
+                                                "csa.mask.sel");
+    // MaskPhi wants to use the most recently updated mask. That's the one
+    // that corresponds to the last Part.
+    if (Part == State.UF - 1)
+      cast<PHINode>(MaskPhi)->addIncoming(MaskSel, State.CFG.PrevBB);
+    return MaskSel;
   }
   case VPInstruction::CSAAnyActive: {
     Value *WidenedCond = State.get(getOperand(0), Part);
@@ -727,8 +727,8 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
   }
   case VPInstruction::CSAAnyActiveEVL: {
     Value *WidenedCond = State.get(getOperand(0), Part);
-    Value *AllOnesMask = Constant::getAllOnesValue(VectorType::get(
-        Type::getInt1Ty(State.Builder.getContext()), State.VF));
+    Value *AllOnesMask = Constant::getAllOnesValue(
+        VectorType::get(Type::getInt1Ty(State.Builder.getContext()), State.VF));
     Value *EVL = State.get(getOperand(1), Part, /*NeedsScalar=*/true);
 
     Value *StartValue =
@@ -2168,7 +2168,7 @@ InstructionCost VPCSAHeaderPHIRecipe::computeCost(ElementCount VF,
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPCSADataUpdateRecipe::print(raw_ostream &O, const Twine &Indent,
-                                 VPSlotTracker &SlotTracker) const {
+                                  VPSlotTracker &SlotTracker) const {
   O << Indent << "EMIT ";
   printAsOperand(O, SlotTracker);
   O << " = csa-data-update ";
@@ -2224,7 +2224,7 @@ InstructionCost VPCSADataUpdateRecipe::computeCost(ElementCount VF,
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPCSAExtractScalarRecipe::print(raw_ostream &O, const Twine &Indent,
-                                 VPSlotTracker &SlotTracker) const {
+                                     VPSlotTracker &SlotTracker) const {
   O << Indent << "EMIT ";
   printAsOperand(O, SlotTracker);
   O << " = CSA-EXTRACT-SCALAR ";
@@ -2308,8 +2308,8 @@ VPCSAExtractScalarRecipe::computeCost(ElementCount VF,
                                     CostKind);
   } else {
     // ActiveLaneIdxs
-    C += TTI.getArithmeticInstrCost(Instruction::Select, MaskTy->getScalarType(),
-                                    CostKind);
+    C += TTI.getArithmeticInstrCost(Instruction::Select,
+                                    MaskTy->getScalarType(), CostKind);
     // MaybeLastIdx
     C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
                                     CostKind);
@@ -2320,7 +2320,7 @@ VPCSAExtractScalarRecipe::computeCost(ElementCount VF,
     C += TTI.getArithmeticInstrCost(Instruction::ICmp, MaskTy->getScalarType(),
                                     CostKind);
     // And
-    C += TTI.getArithmeticInstrCost(Instruction::And , MaskTy->getScalarType(),
+    C += TTI.getArithmeticInstrCost(Instruction::And, MaskTy->getScalarType(),
                                     CostKind);
     // LastIdx
     C += TTI.getArithmeticInstrCost(Instruction::Select, VTy->getScalarType(),

``````````

</details>


https://github.com/llvm/llvm-project/pull/106560


More information about the llvm-commits mailing list