[llvm] [LV][VPlan] Add initial support for CSA vectorization (PR #106560)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 29 07:16:33 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-llvm-transforms

Author: Michael Maitland (michaelmaitland)

<details>
<summary>Changes</summary>

This PR contains a series of commits which together implement an initial version of conditional scalar assignment (CSA) vectorization. I will edit this description with a link to the RFC which gives a more in-depth explanation of the changes.

This patch is further tested in https://github.com/llvm/llvm-test-suite/pull/155.

---

Patch is 331.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/106560.diff


19 Files Affected:

- (added) llvm/include/llvm/Analysis/CSADescriptors.h (+78) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+9) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2) 
- (modified) llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h (+18) 
- (modified) llvm/lib/Analysis/CMakeLists.txt (+1) 
- (added) llvm/lib/Analysis/CSADescriptors.cpp (+73) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+5) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (+4) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp (+30-4) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+203-10) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.cpp (+5-2) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+195) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+367-2) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+49) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.h (+9) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+3) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp (+3-3) 
- (added) llvm/test/Transforms/LoopVectorize/RISCV/csa.ll (+4234) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/CSADescriptors.h b/llvm/include/llvm/Analysis/CSADescriptors.h
new file mode 100644
index 00000000000000..3f95b3484d1e22
--- /dev/null
+++ b/llvm/include/llvm/Analysis/CSADescriptors.h
@@ -0,0 +1,78 @@
+//===- llvm/Analysis/CSADescriptors.h - CSA Descriptors --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file "describes" conditional scalar assignments (CSA).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_CSADESCRIPTORS_H
+#define LLVM_ANALYSIS_CSADESCRIPTORS_H
+
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Value.h"
+
+namespace llvm {
+
+/// A Conditional Scalar Assignment (CSA) is an assignment from an initial
+/// scalar that may or may not occur.
+class CSADescriptor {
+  /// If the conditional assignment occurs inside a loop, then Phi chooses
+  /// the value of the assignment from the entry block or the loop body block.
+  PHINode *Phi = nullptr;
+
+  /// The initial value of the CSA. If the condition guarding the assignment is
+  /// not met, then the assignment retains this value.
+  Value *InitScalar = nullptr;
+
+  /// The instruction inside the loop that performs the conditional assignment.
+  Instruction *Assignment = nullptr;
+
+  /// Create a CSA Descriptor that models an invalid CSA.
+  CSADescriptor() = default;
+
+  /// Create a CSA Descriptor that models a valid CSA.
+  CSADescriptor(PHINode *Phi, Instruction *Assignment, Value *InitScalar)
+      : Phi(Phi), InitScalar(InitScalar), Assignment(Assignment) {}
+
+public:
+  /// If Phi is the root of a CSA, return the CSADescriptor of the CSA rooted by
+  /// Phi. Otherwise, return an invalid descriptor (isValid() returns false).
+  static CSADescriptor isCSAPhi(PHINode *Phi, Loop *TheLoop);
+
+  /// Same as isValid(); explicit so it only converts in boolean contexts.
+  explicit operator bool() const { return isValid(); }
+
+  /// Returns whether \p SI is the assignment of the CSA described by \p Desc.
+  static bool isCSASelect(const CSADescriptor &Desc, SelectInst *SI) {
+    return Desc.getAssignment() == SI;
+  }
+
+  /// Return whether this CSADescriptor models a valid CSA.
+  bool isValid() const { return Phi && InitScalar && Assignment; }
+
+  /// Return the PHI that roots this CSA.
+  PHINode *getPhi() const { return Phi; }
+
+  /// Return the initial value of the CSA. This is the value if the conditional
+  /// assignment does not occur.
+  Value *getInitScalar() const { return InitScalar; }
+
+  /// Return the instruction that performs the conditional assignment.
+  Instruction *getAssignment() const { return Assignment; }
+
+  /// Return the select condition guarding the assignment, or nullptr if none.
+  Value *getCond() const {
+    if (auto *SI = dyn_cast_or_null<SelectInst>(Assignment))
+      return SI->getCondition();
+    return nullptr;
+  }
+};
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_CSADESCRIPTORS_H
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index b2124c6106198e..d0192f7d90a812 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1767,6 +1767,10 @@ class TargetTransformInfo {
         : EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
   };
 
+  /// \returns true if the loop vectorizer should vectorize conditional
+  /// scalar assignments for the target.
+  bool enableCSAVectorization() const;
+
   /// \returns How the target needs this vector-predicated operation to be
   /// transformed.
   VPLegalization getVPLegalizationStrategy(const VPIntrinsic &PI) const;
@@ -2175,6 +2179,7 @@ class TargetTransformInfo::Concept {
   virtual bool supportsScalableVectors() const = 0;
   virtual bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
                                      Align Alignment) const = 0;
+  virtual bool enableCSAVectorization() const = 0;
   virtual VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -2940,6 +2945,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.hasActiveVectorLength(Opcode, DataType, Alignment);
   }
 
+  bool enableCSAVectorization() const override {
+    return Impl.enableCSAVectorization();
+  }
+
   VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
     return Impl.getVPLegalizationStrategy(PI);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 11b07ac0b7fc47..dbf0cf888e168a 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -956,6 +956,8 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool enableCSAVectorization() const { return false; }
+
   TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
     return TargetTransformInfo::VPLegalization(
diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index 0f4d1355dd2bfe..7ef29a8cb36e49 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -27,6 +27,7 @@
 #define LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONLEGALITY_H
 
 #include "llvm/ADT/MapVector.h"
+#include "llvm/Analysis/CSADescriptors.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Support/TypeSize.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
@@ -257,6 +258,10 @@ class LoopVectorizationLegality {
   /// induction descriptor.
   using InductionList = MapVector<PHINode *, InductionDescriptor>;
 
+  /// CSAList contains the CSA descriptors for all the CSAs that were found
+  /// in the loop, rooted by their phis.
+  using CSAList = MapVector<PHINode *, CSADescriptor>;
+
   /// RecurrenceSet contains the phi nodes that are recurrences other than
   /// inductions and reductions.
   using RecurrenceSet = SmallPtrSet<const PHINode *, 8>;
@@ -309,6 +314,12 @@ class LoopVectorizationLegality {
   /// Returns True if V is a Phi node of an induction variable in this loop.
   bool isInductionPhi(const Value *V) const;
 
+  /// \returns the conditional scalar assignments discovered in the loop.
+  const CSAList &getCSAs() const { return CSAs; }
+
+  /// \returns true when \p Phi roots one of the CSAs recorded for this loop.
+  bool isCSAPhi(PHINode *Phi) const { return CSAs.find(Phi) != CSAs.end(); }
+
   /// Returns a pointer to the induction descriptor, if \p Phi is an integer or
   /// floating point induction.
   const InductionDescriptor *getIntOrFpInductionDescriptor(PHINode *Phi) const;
@@ -463,6 +474,10 @@ class LoopVectorizationLegality {
   void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID,
                        SmallPtrSetImpl<Value *> &AllowedExit);
 
+  /// Record \p Phi as a CSA and allow it to have users outside the loop.
+  void addCSAPhi(PHINode *Phi, const CSADescriptor &CSADesc,
+                 SmallPtrSetImpl<Value *> &AllowedExit);
+
   /// The loop that we evaluate.
   Loop *TheLoop;
 
@@ -507,6 +522,9 @@ class LoopVectorizationLegality {
   /// variables can be pointers.
   InductionList Inductions;
 
+  /// Holds the conditional scalar assignments
+  CSAList CSAs;
+
   /// Holds all the casts that participate in the update chain of the induction
   /// variables, and that have been proven to be redundant (possibly under a
   /// runtime guard). These casts can be ignored when creating the vectorized
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index 393803fad89383..24ca426990d9ed 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -46,6 +46,7 @@ add_llvm_component_library(LLVMAnalysis
   CostModel.cpp
   CodeMetrics.cpp
   ConstantFolding.cpp
+  CSADescriptors.cpp
   CtxProfAnalysis.cpp
   CycleAnalysis.cpp
   DDG.cpp
diff --git a/llvm/lib/Analysis/CSADescriptors.cpp b/llvm/lib/Analysis/CSADescriptors.cpp
new file mode 100644
index 00000000000000..d0377c8c16de33
--- /dev/null
+++ b/llvm/lib/Analysis/CSADescriptors.cpp
@@ -0,0 +1,73 @@
+//=== llvm/Analysis/CSADescriptors.cpp - CSA Descriptors -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file "describes" conditional scalar assignments (CSA).
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CSADescriptors.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+#define DEBUG_TYPE "csa-descriptors"
+
+CSADescriptor CSADescriptor::isCSAPhi(PHINode *Phi, Loop *TheLoop) {
+  // Return CSADescriptor that describes a CSA that matches one of these
+  // patterns:
+  //   phi loop_inv, (select cmp, value, phi)
+  //   phi loop_inv, (select cmp, phi, value)
+  //   phi (select cmp, value, phi), loop_inv
+  //   phi (select cmp, phi, value), loop_inv
+  // If the CSA does not match any of these patterns, return a CSADescriptor
+  // that describes an invalid CSA.
+
+  // Must be a scalar. Named Ty so the llvm::Type class name is not shadowed.
+  Type *Ty = Phi->getType();
+  if (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isPointerTy())
+    return CSADescriptor();
+
+  // A CSA is rooted by a two-input loop-header PHI: one incoming value is the
+  // loop-invariant initial value, the other is the conditional select.
+  // This is a public entry point, so do not rely on the caller having
+  // filtered out PHIs that live outside the loop header.
+  if (Phi->getParent() != TheLoop->getHeader() ||
+      Phi->getNumIncomingValues() != 2)
+    return CSADescriptor();
+  auto SelectInstIt = find_if(Phi->incoming_values(), [&Phi](Use &U) {
+    return match(U.get(), m_Select(m_Value(), m_Specific(Phi), m_Value())) ||
+           match(U.get(), m_Select(m_Value(), m_Value(), m_Specific(Phi)));
+  });
+  if (SelectInstIt == Phi->incoming_values().end())
+    return CSADescriptor();
+  auto LoopInvIt = find_if(Phi->incoming_values(), [&](Use &U) {
+    return U.get() != *SelectInstIt && TheLoop->isLoopInvariant(U.get());
+  });
+  if (LoopInvIt == Phi->incoming_values().end())
+    return CSADescriptor();
+
+  // Apart from their uses of each other, Phi and Sel may only be used
+  // outside the loop.
+  auto IsOnlyUsedOutsideLoop = [=](Value *V, Value *Ignore) {
+    return all_of(V->users(), [Ignore, TheLoop](User *U) {
+      if (U == Ignore)
+        return true;
+      if (auto *I = dyn_cast<Instruction>(U))
+        return !TheLoop->contains(I);
+      return true;
+    });
+  };
+  auto *Sel = cast<SelectInst>(SelectInstIt->get());
+  auto *LoopInv = LoopInvIt->get();
+  if (!IsOnlyUsedOutsideLoop(Phi, Sel) || !IsOnlyUsedOutsideLoop(Sel, Phi))
+    return CSADescriptor();
+
+  return CSADescriptor(Phi, Sel, LoopInv);
+}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 2c26493bd3f1ca..4f882475b74e74 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1304,6 +1304,10 @@ bool TargetTransformInfo::preferEpilogueVectorization() const {
   return TTIImpl->preferEpilogueVectorization();
 }
 
+bool TargetTransformInfo::enableCSAVectorization() const {
+  return TTIImpl->enableCSAVectorization();
+}
+
 TargetTransformInfo::VPLegalization
 TargetTransformInfo::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
   return TTIImpl->getVPLegalizationStrategy(VPI);
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 537c62bb0aacd1..b76f2fc72c6f47 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1985,6 +1985,11 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                   C2.ScaleCost, C2.ImmCost, C2.SetupCost);
 }
 
+bool RISCVTTIImpl::enableCSAVectorization() const {
+  return ST->getProcFamily() == RISCVSubtarget::SiFive7 &&
+         ST->hasVInstructions();
+}
+
 bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
   auto *VTy = dyn_cast<VectorType>(DataTy);
   if (!VTy || VTy->isScalableTy())
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index cc69e1d118b5a1..17245150ec10ae 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,6 +287,10 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
     return TLI->isVScaleKnownToBeAPowerOfTwo();
   }
 
+  /// \returns true if the loop vectorizer should vectorize conditional
+  /// scalar assignments for the target.
+  bool enableCSAVectorization() const;
+
   /// \returns How the target needs this vector-predicated operation to be
   /// transformed.
   TargetTransformInfo::VPLegalization
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 66a779da8c25bc..9633ba9cc70ee9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -79,6 +79,10 @@ static cl::opt<LoopVectorizeHints::ScalableForceKind>
                 "Scalable vectorization is available and favored when the "
                 "cost is inconclusive.")));
 
+static cl::opt<bool>
+    EnableCSA("enable-csa-vectorization", cl::init(false), cl::Hidden,
+              cl::desc("Control whether CSA loop vectorization is enabled"));
+
 /// Maximum vectorization interleave count.
 static const unsigned MaxInterleaveFactor = 16;
 
@@ -749,6 +753,15 @@ bool LoopVectorizationLegality::setupOuterLoopInductions() {
   return llvm::all_of(Header->phis(), IsSupportedPhi);
 }
 
+void LoopVectorizationLegality::addCSAPhi(
+    PHINode *Phi, const CSADescriptor &CSADesc,
+    SmallPtrSetImpl<Value *> &AllowedExit) {
+  assert(CSADesc.isValid() && "Expected Valid CSADescriptor");
+  CSAs.insert({Phi, CSADesc});
+  AllowedExit.insert(Phi);
+  LLVM_DEBUG(dbgs() << "LV: found legal CSA opportunity" << *Phi << "\n");
+}
+
 /// Checks if a function is scalarizable according to the TLI, in
 /// the sense that it should be vectorized and then expanded in
 /// multiple scalar calls. This is represented in the
@@ -866,14 +879,23 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
           continue;
         }
 
-        // As a last resort, coerce the PHI to a AddRec expression
-        // and re-try classifying it a an induction PHI.
+        // Try to coerce the PHI to a AddRec expression and re-try classifying
+        // it a an induction PHI.
         if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true) &&
             !IsDisallowedStridedPointerInduction(ID)) {
           addInductionPhi(Phi, ID, AllowedExit);
           continue;
         }
 
+        // Check if the PHI can be classified as a CSA PHI.
+        if (EnableCSA || (TTI->enableCSAVectorization() &&
+                          EnableCSA.getNumOccurrences() == 0)) {
+          if (auto CSADesc = CSADescriptor::isCSAPhi(Phi, TheLoop)) {
+            addCSAPhi(Phi, CSADesc, AllowedExit);
+            continue;
+          }
+        }
+
         reportVectorizationFailure("Found an unidentified PHI",
             "value that could not be identified as "
             "reduction is used outside the loop",
@@ -1555,11 +1577,15 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
   for (const auto &Reduction : getReductionVars())
     ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
 
+  SmallPtrSet<const Value *, 8> CSALiveOuts;
+  for (const auto &CSA: getCSAs())
+    CSALiveOuts.insert(CSA.second.getAssignment());
+
   // TODO: handle non-reduction outside users when tail is folded by masking.
   for (auto *AE : AllowedExit) {
     // Check that all users of allowed exit values are inside the loop or
-    // are the live-out of a reduction.
-    if (ReductionLiveOuts.count(AE))
+    // are the live-out of a reduction or a CSA
+    if (ReductionLiveOuts.count(AE) || CSALiveOuts.count(AE))
       continue;
     for (User *U : AE->users()) {
       Instruction *UI = cast<Instruction>(U);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 56f51e14a6eba9..5e45f500482826 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -180,6 +180,8 @@ const char LLVMLoopVectorizeFollowupEpilogue[] =
 STATISTIC(LoopsVectorized, "Number of loops vectorized");
 STATISTIC(LoopsAnalyzed, "Number of loops analyzed for vectorization");
 STATISTIC(LoopsEpilogueVectorized, "Number of epilogues vectorized");
+STATISTIC(CSAsVectorized,
+          "Number of conditional scalar assignments vectorized");
 
 static cl::opt<bool> EnableEpilogueVectorization(
     "enable-epilogue-vectorization", cl::init(true), cl::Hidden,
@@ -500,6 +502,10 @@ class InnerLoopVectorizer {
   virtual std::pair<BasicBlock *, Value *>
   createVectorizedLoopSkeleton(const SCEV2ValueTy &ExpandedSCEVs);
 
+  /// For each vectorized CSA, give the exit-block LCSSA phis of its live-out
+  /// the scalar extracted from the vector loop as their middle-block value.
+  void fixCSALiveOuts(VPTransformState &State, VPlan &Plan);
+
   /// Fix the vectorized code, taking care of header phi's, live-outs, and more.
   void fixVectorizedLoop(VPTransformState &State, VPlan &Plan);
 
@@ -2932,6 +2938,25 @@ LoopVectorizationCostModel::getVectorIntrinsicCost(CallInst *CI,
                                    TargetTransformInfo::TCK_RecipThroughput);
 }
 
+void InnerLoopVectorizer::fixCSALiveOuts(VPTransformState &State, VPlan &Plan) {
+  for (const auto &Entry : Plan.getCSAStates()) {
+    VPCSADataUpdateRecipe *DataUpdate = Entry.second->getDataUpdate();
+    assert(DataUpdate &&
+           "VPDataUpdate must have been introduced prior to fixing live outs");
+    Value *Scalar = State.get(Entry.second->getExtractScalarRecipe(), 0,
+                              /*NeedsScalar=*/true);
+    // Gather exit-block LCSSA phis of the original live-out, then feed them.
+    llvm::SmallPtrSet<PHINode *, 2> ExitPhis;
+    for (User *Usr : DataUpdate->getUnderlyingValue()->users()) {
+      auto *LCSSAPhi = dyn_cast<PHINode>(Usr);
+      if (LCSSAPhi && LCSSAPhi->getParent() == LoopExitBlock)
+        ExitPhis.insert(LCSSAPhi);
+    }
+    for (PHINode *LCSSAPhi : ExitPhis)
+      LCSSAPhi->addIncoming(Scalar, LoopMiddleBlock);
+  }
+}
+
 void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
                                             VPlan &Plan) {
   // Fix widened non-induction PHIs by setting up the PHI operands.
@@ -2972,6 +2997,8 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
                    getOrCreateVectorTripCount(VectorLoop->getLoopPreheader()),
                    IVEndValues[Entry.first], LoopMiddleBlock,
                    VectorLoop->getHeader(), Plan, State);
+
+    fixCSALiveOuts(State, Plan);
   }
 
   // Fix live-out phis not already fixed earlier.
@@ -4110,7 +4137,6 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
   // found modulo the vectorization factor is not zero, try to fold the tail
   // by masking.
   // FIXME: look for a smaller MaxVF that does divide TC rather than masking.
-  setTailFoldingStyles(MaxFactors.ScalableVF.isScalable(), UserIC);
   if (foldTailByMasking()) {
     if (getTailFoldingStyle() == TailFoldingStyle::DataWithEVL) {
       LLVM_DEBUG(
@@ -4482,6 +4508,9 @@ static bool willGenerateVectors(VPlan &Plan, ElementCount VF,
       case VPDef::VPEVLBasedIVPHISC:
       case VPDef::VPPredInstPHISC:
       case VPDef::VPBranchOnMaskSC:
+      case VPRecipeBase::VPCSADataUpdateSC:
+      case VPRecipeBase::VPCSAExtractScalarSC:
+      case VPRecipeBase::VPCSAHe...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/106560


More information about the llvm-commits mailing list