[llvm] [VPlan] Introduce recipes for VP loads and stores. (PR #87816)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 18 08:04:13 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/87816

>From 4533d9250262ed5b076cf826167551b83d3edf89 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 17 Apr 2024 15:54:51 +0100
Subject: [PATCH 1/2] [VPlan] Introduce recipes for VP loads and stores.

Introduce new subclasses of VPWidenMemoryRecipe for VP
(vector-predicated) loads and stores to address multiple TODOs from
https://github.com/llvm/llvm-project/pull/76172

Note that the introduction of the new recipes also improves code-gen for
VP gather/scatters by removing the redundant header mask. With the new
approach, it is not sufficient to look at users of the widened canonical
IV to find all uses of the header mask.

In some cases, a widened IV is used instead of separately widening the
canonical IV. To handle those cases, also treat canonical
VPWidenIntOrFpInductionRecipes in the loop header as widened canonical IVs
when searching for the header mask, to make sure all widened memory recipes
are processed.

Depends on https://github.com/llvm/llvm-project/pull/87411.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 161 +++++++++---------
 llvm/lib/Transforms/Vectorize/VPlan.h         | 108 ++++++++++--
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |   2 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  22 ++-
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 100 ++++++-----
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |   2 +
 ...rize-force-tail-with-evl-gather-scatter.ll |  14 +-
 .../RISCV/vplan-vp-intrinsics.ll              |   6 +-
 8 files changed, 263 insertions(+), 152 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a8272f45025358..48f44134cb4f01 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9316,52 +9316,6 @@ void VPReplicateRecipe::execute(VPTransformState &State) {
       State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, Lane), State);
 }
 
-/// Creates either vp_store or vp_scatter intrinsics calls to represent
-/// predicated store/scatter.
-static Instruction *
-lowerStoreUsingVectorIntrinsics(IRBuilderBase &Builder, Value *Addr,
-                                Value *StoredVal, bool IsScatter, Value *Mask,
-                                Value *EVL, const Align &Alignment) {
-  CallInst *Call;
-  if (IsScatter) {
-    Call = Builder.CreateIntrinsic(Type::getVoidTy(EVL->getContext()),
-                                   Intrinsic::vp_scatter,
-                                   {StoredVal, Addr, Mask, EVL});
-  } else {
-    VectorBuilder VBuilder(Builder);
-    VBuilder.setEVL(EVL).setMask(Mask);
-    Call = cast<CallInst>(VBuilder.createVectorInstruction(
-        Instruction::Store, Type::getVoidTy(EVL->getContext()),
-        {StoredVal, Addr}));
-  }
-  Call->addParamAttr(
-      1, Attribute::getWithAlignment(Call->getContext(), Alignment));
-  return Call;
-}
-
-/// Creates either vp_load or vp_gather intrinsics calls to represent
-/// predicated load/gather.
-static Instruction *lowerLoadUsingVectorIntrinsics(IRBuilderBase &Builder,
-                                                   VectorType *DataTy,
-                                                   Value *Addr, bool IsGather,
-                                                   Value *Mask, Value *EVL,
-                                                   const Align &Alignment) {
-  CallInst *Call;
-  if (IsGather) {
-    Call =
-        Builder.CreateIntrinsic(DataTy, Intrinsic::vp_gather, {Addr, Mask, EVL},
-                                nullptr, "wide.masked.gather");
-  } else {
-    VectorBuilder VBuilder(Builder);
-    VBuilder.setEVL(EVL).setMask(Mask);
-    Call = cast<CallInst>(VBuilder.createVectorInstruction(
-        Instruction::Load, DataTy, Addr, "vp.op.load"));
-  }
-  Call->addParamAttr(
-      0, Attribute::getWithAlignment(Call->getContext(), Alignment));
-  return Call;
-}
-
 void VPWidenLoadRecipe::execute(VPTransformState &State) {
   auto *LI = cast<LoadInst>(&Ingredient);
 
@@ -9383,24 +9337,7 @@ void VPWidenLoadRecipe::execute(VPTransformState &State) {
         Mask = Builder.CreateVectorReverse(Mask, "reverse");
     }
 
-    // TODO: split this into several classes for better design.
-    if (State.EVL) {
-      assert(State.UF == 1 && "Expected only UF == 1 when vectorizing with "
-                              "explicit vector length.");
-      assert(cast<VPInstruction>(State.EVL)->getOpcode() ==
-                 VPInstruction::ExplicitVectorLength &&
-             "EVL must be VPInstruction::ExplicitVectorLength.");
-      Value *EVL = State.get(State.EVL, VPIteration(0, 0));
-      // If EVL is not nullptr, then EVL must be a valid value set during plan
-      // creation, possibly default value = whole vector register length. EVL
-      // is created only if TTI prefers predicated vectorization, thus if EVL
-      // is not nullptr it also implies preference for predicated
-      // vectorization.
-      // FIXME: Support reverse loading after vp_reverse is added.
-      NewLI = lowerLoadUsingVectorIntrinsics(
-          Builder, DataTy, State.get(getAddr(), Part, !CreateGather),
-          CreateGather, Mask, EVL, Alignment);
-    } else if (CreateGather) {
+    if (CreateGather) {
       Value *VectorGep = State.get(getAddr(), Part);
       NewLI = Builder.CreateMaskedGather(DataTy, VectorGep, Alignment, Mask,
                                          nullptr, "wide.masked.gather");
@@ -9425,6 +9362,44 @@ void VPWidenLoadRecipe::execute(VPTransformState &State) {
   }
 }
 
+void VPWidenVPLoadRecipe::execute(VPTransformState &State) {
+  assert(State.UF == 1 && "Expected only UF == 1 when vectorizing with "
+                          "explicit vector length.");
+  // FIXME: Support reverse loading after vp_reverse is added.
+  assert(!isReverse() && "Reverse loads are not implemented yet.");
+
+  auto *LI = cast<LoadInst>(&Ingredient);
+  auto *DataTy = VectorType::get(getLoadStoreType(&Ingredient), State.VF);
+  const Align Alignment = getLoadStoreAlignment(&Ingredient);
+  bool CreateGather = !isConsecutive();
+
+  auto &Builder = State.Builder;
+  State.setDebugLocFrom(getDebugLoc());
+  CallInst *NewLI = nullptr;
+  Value *EVL = State.get(getEVL(), VPIteration(0, 0));
+  // Consecutive loads use a scalar base pointer, gathers a vector of them.
+  Value *Addr = State.get(getAddr(), 0, !CreateGather);
+  // Without an explicit mask, only EVL limits the active lanes.
+  Value *Mask =
+      getMask() ? State.get(getMask(), 0)
+                : Builder.CreateVectorSplat(State.VF, Builder.getTrue());
+  if (CreateGather) {
+    NewLI =
+        Builder.CreateIntrinsic(DataTy, Intrinsic::vp_gather, {Addr, Mask, EVL},
+                                nullptr, "wide.masked.gather");
+  } else {
+    VectorBuilder VBuilder(Builder);
+    VBuilder.setEVL(EVL).setMask(Mask);
+    NewLI = cast<CallInst>(VBuilder.createVectorInstruction(
+        Instruction::Load, DataTy, Addr, "vp.op.load"));
+  }
+  NewLI->addParamAttr(
+      0, Attribute::getWithAlignment(NewLI->getContext(), Alignment));
+
+  State.addMetadata(NewLI, LI);
+  State.set(this, NewLI, 0);
+}
+
 void VPWidenStoreRecipe::execute(VPTransformState &State) {
   auto *SI = cast<StoreInst>(&Ingredient);
 
@@ -9448,7 +9423,6 @@ void VPWidenStoreRecipe::execute(VPTransformState &State) {
 
     Value *StoredVal = State.get(StoredVPValue, Part);
     if (isReverse()) {
-      assert(!State.EVL && "reversing not yet implemented with EVL");
       // If we store to reverse consecutive memory locations, then we need
       // to reverse the order of elements in the stored value.
       StoredVal = Builder.CreateVectorReverse(StoredVal, "reverse");
@@ -9456,23 +9430,6 @@ void VPWidenStoreRecipe::execute(VPTransformState &State) {
       // another expression. So don't call resetVectorValue(StoredVal).
     }
-    // TODO: split this into several classes for better design.
-    if (State.EVL) {
-      assert(State.UF == 1 && "Expected only UF == 1 when vectorizing with "
-                              "explicit vector length.");
-      assert(cast<VPInstruction>(State.EVL)->getOpcode() ==
-                 VPInstruction::ExplicitVectorLength &&
-             "EVL must be VPInstruction::ExplicitVectorLength.");
-      Value *EVL = State.get(State.EVL, VPIteration(0, 0));
-      // If EVL is not nullptr, then EVL must be a valid value set during plan
-      // creation, possibly default value = whole vector register length. EVL
-      // is created only if TTI prefers predicated vectorization, thus if EVL
-      // is not nullptr it also implies preference for predicated
-      // vectorization.
-      // FIXME: Support reverse store after vp_reverse is added.
-      NewSI = lowerStoreUsingVectorIntrinsics(
-          Builder, State.get(getAddr(), Part, !CreateScatter), StoredVal,
-          CreateScatter, Mask, EVL, Alignment);
-    } else if (CreateScatter) {
+    if (CreateScatter) {
       Value *VectorGep = State.get(getAddr(), Part);
       NewSI =
           Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment, Mask);
@@ -9487,6 +9445,45 @@ void VPWidenStoreRecipe::execute(VPTransformState &State) {
   }
 }
 
+void VPWidenVPStoreRecipe::execute(VPTransformState &State) {
+  assert(State.UF == 1 && "Expected only UF == 1 when vectorizing with "
+                          "explicit vector length.");
+  // FIXME: Support reverse storing after vp_reverse is added.
+  assert(!isReverse() && "Reverse store are not implemented yet.");
+
+  auto *SI = cast<StoreInst>(&Ingredient);
+
+  bool CreateScatter = !isConsecutive();
+  const Align Alignment = getLoadStoreAlignment(&Ingredient);
+
+  auto &Builder = State.Builder;
+  State.setDebugLocFrom(getDebugLoc());
+  CallInst *NewSI = nullptr;
+  Value *StoredVal = State.get(getStoredValue(), 0);
+  Value *EVL = State.get(getEVL(), VPIteration(0, 0));
+  // Without an explicit mask, only EVL limits the active lanes.
+  Value *Mask =
+      getMask() ? State.get(getMask(), 0)
+                : Builder.CreateVectorSplat(State.VF, Builder.getTrue());
+  // Consecutive stores use a scalar base pointer, scatters a vector of them.
+  Value *Addr = State.get(getAddr(), 0, !CreateScatter);
+  if (CreateScatter) {
+    NewSI = Builder.CreateIntrinsic(Type::getVoidTy(EVL->getContext()),
+                                    Intrinsic::vp_scatter,
+                                    {StoredVal, Addr, Mask, EVL});
+  } else {
+    VectorBuilder VBuilder(Builder);
+    VBuilder.setEVL(EVL).setMask(Mask);
+    NewSI = cast<CallInst>(VBuilder.createVectorInstruction(
+        Instruction::Store, Type::getVoidTy(EVL->getContext()),
+        {StoredVal, Addr}));
+  }
+  NewSI->addParamAttr(
+      1, Attribute::getWithAlignment(NewSI->getContext(), Alignment));
+
+  State.addMetadata(NewSI, SI);
+}
+
 // Determine how to lower the scalar epilogue, which depends on 1) optimising
 // for minimum code-size, 2) predicate compiler options, 3) loop hints forcing
 // predication, and 4) a TTI hook that analyses whether the loop is suitable
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 334b10e2e5d097..a0ecac0fae76a8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -242,15 +242,6 @@ struct VPTransformState {
   ElementCount VF;
   unsigned UF;
 
-  /// If EVL (Explicit Vector Length) is not nullptr, then EVL must be a valid
-  /// value set during plan transformation, possibly a default value = whole
-  /// vector register length. EVL is created only if TTI prefers predicated
-  /// vectorization, thus if EVL is not nullptr it also implies preference for
-  /// predicated vectorization.
-  /// TODO: this is a temporarily solution, the EVL must be explicitly used by
-  /// the recipes and must be removed here.
-  VPValue *EVL = nullptr;
-
   /// Hold the indices to generate specific scalar instructions. Null indicates
   /// that all instances are to be generated, using either scalar or vector
   /// instructions.
@@ -877,6 +868,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPBranchOnMaskSC:
     case VPRecipeBase::VPWidenLoadSC:
     case VPRecipeBase::VPWidenStoreSC:
+    case VPRecipeBase::VPWidenVPLoadSC:
+    case VPRecipeBase::VPWidenVPStoreSC:
       // TODO: Widened stores don't define a value, but widened loads do. Split
       // the recipes to be able to make widened loads VPSingleDefRecipes.
       return false;
@@ -2318,11 +2311,15 @@ class VPWidenMemoryRecipe : public VPRecipeBase {
   }
 
 public:
-  VPWidenMemoryRecipe *clone() override = 0;
+  VPWidenMemoryRecipe *clone() override {
+    llvm_unreachable("cloning not supported");
+  }

   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPDef::VPWidenLoadSC ||
-           R->getVPDefID() == VPDef::VPWidenStoreSC;
+           R->getVPDefID() == VPDef::VPWidenStoreSC ||
+           R->getVPDefID() == VPDef::VPWidenVPLoadSC ||
+           R->getVPDefID() == VPDef::VPWidenVPStoreSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2390,13 +2387,49 @@ struct VPWidenLoadRecipe final : public VPWidenMemoryRecipe, public VPValue {
   bool onlyFirstLaneUsed(const VPValue *Op) const override {
     assert(is_contained(operands(), Op) &&
            "Op must be an operand of the recipe");
-
-    // Widened, consecutive loads operations only demand the first lane of
-    // their address.
+    // Widened, consecutive loads only demand the first lane of their address.
+    // Unlike stores, loads have no stored-value operand that might also be the
+    // address, so no extra check is needed.
     return Op == getAddr() && isConsecutive();
   }
 };
 
+/// A recipe for widening loads with vector-predication intrinsics, using the
+/// address to load from, the explicit vector length and an optional mask.
+struct VPWidenVPLoadRecipe final : public VPWidenMemoryRecipe, public VPValue {
+  VPWidenVPLoadRecipe(VPWidenLoadRecipe *L, VPValue *EVL, VPValue *Mask)
+      : VPWidenMemoryRecipe(VPDef::VPWidenVPLoadSC,
+                            *cast<LoadInst>(&L->getIngredient()),
+                            {L->getAddr(), EVL}, L->isConsecutive(),
+                            L->isReverse(), L->getDebugLoc()),
+        VPValue(this, &getIngredient()) {
+    setMask(Mask);
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPWidenVPLoadSC)
+
+  /// Return the EVL operand.
+  VPValue *getEVL() const { return getOperand(1); }
+
+  /// Generate the wide vector-predicated load (vp.load or vp.gather).
+  void execute(VPTransformState &State) override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// Returns true if the recipe only uses the first lane of operand \p Op.
+  bool onlyFirstLaneUsed(const VPValue *Op) const override {
+    assert(is_contained(operands(), Op) &&
+           "Op must be an operand of the recipe");
+    // Widened loads only demand the first lane of EVL and consecutive loads
+    // only demand the first lane of their address.
+    return Op == getEVL() || (Op == getAddr() && isConsecutive());
+  }
+};
+
 /// A recipe for widening store operations, using the stored value, the address
 /// to store to and an optional mask.
 struct VPWidenStoreRecipe final : public VPWidenMemoryRecipe {
@@ -2436,6 +2469,51 @@ struct VPWidenStoreRecipe final : public VPWidenMemoryRecipe {
     return Op == getAddr() && isConsecutive() && Op != getStoredValue();
   }
 };
+
+/// A recipe for widening stores with vector-predication intrinsics, using the
+/// stored value, the address, the explicit vector length and an optional mask.
+struct VPWidenVPStoreRecipe final : public VPWidenMemoryRecipe {
+  VPWidenVPStoreRecipe(VPWidenStoreRecipe *S, VPValue *EVL, VPValue *Mask)
+      : VPWidenMemoryRecipe(VPDef::VPWidenVPStoreSC,
+                            *cast<StoreInst>(&S->getIngredient()),
+                            {S->getAddr(), S->getStoredValue(), EVL},
+                            S->isConsecutive(), S->isReverse(),
+                            S->getDebugLoc()) {
+    setMask(Mask);
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPWidenVPStoreSC)
+
+  /// Return the value stored by this recipe.
+  VPValue *getStoredValue() const { return getOperand(1); }
+
+  /// Return the EVL operand.
+  VPValue *getEVL() const { return getOperand(2); }
+
+  /// Generate the wide vector-predicated store (vp.store or vp.scatter).
+  void execute(VPTransformState &State) override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// Returns true if the recipe only uses the first lane of operand \p Op.
+  bool onlyFirstLaneUsed(const VPValue *Op) const override {
+    assert(is_contained(operands(), Op) &&
+           "Op must be an operand of the recipe");
+    if (Op == getEVL()) {
+      assert(getStoredValue() != Op && "unexpected store of EVL");
+      return true;
+    }
+    // Widened, consecutive memory operations only demand the first lane of
+    // their address, unless the same operand is also stored. That latter can
+    // happen with opaque pointers.
+    return Op == getAddr() && isConsecutive() && Op != getStoredValue();
+  }
+};
+
 /// Recipe to expand a SCEV expression.
 class VPExpandSCEVRecipe : public VPSingleDefRecipe {
   const SCEV *Expr;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 130fb04f586e75..d820a65299cf4f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -109,7 +109,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) {
 }
 
 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R) {
-  assert(isa<VPWidenLoadRecipe>(R) &&
+  assert((isa<VPWidenLoadRecipe, VPWidenVPLoadRecipe>(R)) &&
          "Store recipes should not define any values");
   return cast<LoadInst>(&R->getIngredient())->getType();
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 78932643c81fa3..6f3a818b74a058 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -48,6 +48,7 @@ bool VPRecipeBase::mayWriteToMemory() const {
   case VPInterleaveSC:
     return cast<VPInterleaveRecipe>(this)->getNumStoreOperands() > 0;
   case VPWidenStoreSC:
+  case VPWidenVPStoreSC:
     return true;
   case VPReplicateSC:
   case VPWidenCallSC:
@@ -65,6 +66,7 @@ bool VPRecipeBase::mayWriteToMemory() const {
   case VPWidenIntOrFpInductionSC:
   case VPWidenLoadSC:
   case VPWidenPHISC:
+  case VPWidenVPLoadSC:
   case VPWidenSC:
   case VPWidenSelectSC: {
     const Instruction *I =
@@ -82,6 +84,7 @@ bool VPRecipeBase::mayWriteToMemory() const {
 bool VPRecipeBase::mayReadFromMemory() const {
   switch (getVPDefID()) {
   case VPWidenLoadSC:
+  case VPWidenVPLoadSC:
     return true;
   case VPReplicateSC:
   case VPWidenCallSC:
@@ -91,6 +94,7 @@ bool VPRecipeBase::mayReadFromMemory() const {
   case VPPredInstPHISC:
   case VPScalarIVStepsSC:
   case VPWidenStoreSC:
+  case VPWidenVPStoreSC:
     return false;
   case VPBlendSC:
   case VPReductionSC:
@@ -157,6 +161,8 @@ bool VPRecipeBase::mayHaveSideEffects() const {
     return mayWriteToMemory();
   case VPWidenLoadSC:
   case VPWidenStoreSC:
+  case VPWidenVPLoadSC:
+  case VPWidenVPStoreSC:
     assert(
         cast<VPWidenMemoryRecipe>(this)->getIngredient().mayHaveSideEffects() ==
             mayWriteToMemory() &&
@@ -411,8 +417,6 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
     Value *TripCount = State.get(getOperand(1), VPIteration(0, 0));
     Value *AVL = State.Builder.CreateSub(TripCount, Index);
     Value *EVL = GetEVL(State, AVL);
-    assert(!State.EVL && "multiple EVL recipes");
-    State.EVL = this;
     return EVL;
   }
   case VPInstruction::CanonicalIVIncrementForPart: {
@@ -1778,11 +1782,25 @@ void VPWidenLoadRecipe::print(raw_ostream &O, const Twine &Indent,
   printOperands(O, SlotTracker);
 }
 
+void VPWidenVPLoadRecipe::print(raw_ostream &O, const Twine &Indent,
+                                VPSlotTracker &SlotTracker) const {
+  O << Indent << "WIDEN ";
+  printAsOperand(O, SlotTracker);
+  O << " = vp.load ";
+  printOperands(O, SlotTracker); // Operands: addr, EVL[, mask].
+}
+
 void VPWidenStoreRecipe::print(raw_ostream &O, const Twine &Indent,
                                VPSlotTracker &SlotTracker) const {
   O << Indent << "WIDEN store ";
   printOperands(O, SlotTracker);
 }
+
+void VPWidenVPStoreRecipe::print(raw_ostream &O, const Twine &Indent,
+                                 VPSlotTracker &SlotTracker) const {
+  O << Indent << "WIDEN vp.store ";
+  printOperands(O, SlotTracker); // Operands: addr, stored value, EVL[, mask].
+}
 #endif
 
 void VPCanonicalIVPHIRecipe::execute(VPTransformState &State) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 901ecd10c69d8f..7b6634499c6bf8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1203,43 +1203,48 @@ static VPActiveLaneMaskPHIRecipe *addVPLaneMaskPhiAndUpdateExitBranch(
   return LaneMaskPhi;
 }
 
-/// Replaces (ICMP_ULE, WideCanonicalIV, backedge-taken-count) pattern using
-/// the given \p Idiom.
-static void
-replaceHeaderPredicateWith(VPlan &Plan, VPValue &Idiom,
-                           function_ref<bool(VPUser &, unsigned)> Cond = {}) {
+/// Apply \p Fn to all VPInstructions matching the header mask (ICMP_ULE,
+/// WideCanonicalIV, backedge-taken-count) pattern.
+static void forAllHeaderPredicates(VPlan &Plan,
+                                   function_ref<void(VPInstruction &)> Fn) {
+  SmallVector<VPValue *> WideCanonicalIVs;
   auto *FoundWidenCanonicalIVUser =
       find_if(Plan.getCanonicalIV()->users(),
               [](VPUser *U) { return isa<VPWidenCanonicalIVRecipe>(U); });
-  if (FoundWidenCanonicalIVUser == Plan.getCanonicalIV()->users().end())
-    return;
-  auto *WideCanonicalIV =
-      cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser);
-  // Walk users of WideCanonicalIV and replace all compares of the form
+  if (FoundWidenCanonicalIVUser != Plan.getCanonicalIV()->users().end()) {
+    auto *WideCanonicalIV =
+        cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser);
+    WideCanonicalIVs.push_back(WideCanonicalIV);
+  }
+
+  // Also include VPWidenIntOrFpInductionRecipes that represent a widened
+  // version of the canonical induction.
+  VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
+  for (VPRecipeBase &Phi : HeaderVPBB->phis()) {
+    auto *WidenOriginalIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi);
+    if (WidenOriginalIV && WidenOriginalIV->isCanonical())
+      WideCanonicalIVs.push_back(WidenOriginalIV);
+  }
+
+  // Walk users of wide canonical IVs and apply \p Fn to all compares of the
-  // (ICMP_ULE, WideCanonicalIV, backedge-taken-count) with
-  // the given idiom VPValue.
+  // form (ICMP_ULE, WideCanonicalIV, backedge-taken-count). Compares left dead
+  // by \p Fn are deleted afterwards.
   VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
-  for (VPUser *U : SmallVector<VPUser *>(WideCanonicalIV->users())) {
-    auto *CompareToReplace = dyn_cast<VPInstruction>(U);
-    if (!CompareToReplace ||
-        CompareToReplace->getOpcode() != Instruction::ICmp ||
-        CompareToReplace->getPredicate() != CmpInst::ICMP_ULE ||
-        CompareToReplace->getOperand(1) != BTC)
-      continue;
+  for (auto *Wide : WideCanonicalIVs) {
+    for (VPUser *U : SmallVector<VPUser *>(Wide->users())) {
+      auto *CompareToReplace = dyn_cast<VPInstruction>(U);
+      if (!CompareToReplace ||
+          CompareToReplace->getOpcode() != Instruction::ICmp ||
+          CompareToReplace->getPredicate() != CmpInst::ICMP_ULE ||
+          CompareToReplace->getOperand(1) != BTC)
+        continue;

-    assert(CompareToReplace->getOperand(0) == WideCanonicalIV &&
-           "WidenCanonicalIV must be the first operand of the compare");
-    if (Cond) {
-      CompareToReplace->replaceUsesWithIf(&Idiom, Cond);
-      if (!CompareToReplace->getNumUsers())
-        CompareToReplace->eraseFromParent();
-    } else {
-      CompareToReplace->replaceAllUsesWith(&Idiom);
-      CompareToReplace->eraseFromParent();
+      assert(CompareToReplace->getOperand(0) == Wide &&
+             "WidenCanonicalIV must be the first operand of the compare");
+      Fn(*CompareToReplace);
+      recursivelyDeleteDeadRecipes(CompareToReplace);
     }
   }
-  if (!WideCanonicalIV->getNumUsers())
-    WideCanonicalIV->eraseFromParent();
 }
 
 void VPlanTransforms::addActiveLaneMask(
@@ -1271,7 +1276,8 @@ void VPlanTransforms::addActiveLaneMask(
   // Walk users of WideCanonicalIV and replace all compares of the form
   // (ICMP_ULE, WideCanonicalIV, backedge-taken-count) with an
   // active-lane-mask.
-  replaceHeaderPredicateWith(Plan, *LaneMask);
+  forAllHeaderPredicates(
+      Plan, [LaneMask](VPInstruction &I) { I.replaceAllUsesWith(LaneMask); });
 }
 
 /// Add a VPEVLBasedIVPHIRecipe and related recipes to \p Plan and
@@ -1301,17 +1307,7 @@ void VPlanTransforms::addExplicitVectorLength(VPlan &Plan) {
   auto *CanonicalIVPHI = Plan.getCanonicalIV();
   VPValue *StartV = CanonicalIVPHI->getStartValue();
 
-  // TODO: revisit this and try to remove the mask operand.
-  // Walk VPWidenMemoryInstructionRecipe users of WideCanonicalIV and replace
-  // all compares of the form (ICMP_ULE, WideCanonicalIV, backedge-taken-count),
-  // used as mask in VPWidenMemoryInstructionRecipe, with an all-true-mask.
-  Value *TrueMask =
-      ConstantInt::getTrue(CanonicalIVPHI->getScalarType()->getContext());
-  VPValue *VPTrueMask = Plan.getOrAddLiveIn(TrueMask);
-  replaceHeaderPredicateWith(Plan, *VPTrueMask, [](VPUser &U, unsigned) {
-    return isa<VPWidenMemoryRecipe>(U);
-  });
-  // Now create the ExplicitVectorLengthPhi recipe in the main loop.
+  // Create the ExplicitVectorLengthPhi recipe in the main loop.
   auto *EVLPhi = new VPEVLBasedIVPHIRecipe(StartV, DebugLoc());
   EVLPhi->insertAfter(CanonicalIVPHI);
   auto *VPEVL = new VPInstruction(VPInstruction::ExplicitVectorLength,
@@ -1336,6 +1332,34 @@ void VPlanTransforms::addExplicitVectorLength(VPlan &Plan) {
   NextEVLIV->insertBefore(CanonicalIVIncrement);
   EVLPhi->addOperand(NextEVLIV);

+  // Convert every widened memory recipe using a header mask (directly or via
+  // recipes computed from it) into a VP recipe that takes VPEVL as operand.
+  forAllHeaderPredicates(Plan, [VPEVL](VPInstruction &Mask) {
+    for (VPUser *U : collectUsersRecursively(&Mask)) {
+      auto *MemR = dyn_cast<VPWidenMemoryRecipe>(U);
+      if (!MemR)
+        continue;
+      assert(!MemR->isReverse() &&
+             "Reversed memory operations not supported yet.");
+      VPValue *OrigMask = MemR->getMask();
+      assert(OrigMask && "Unmasked widen memory recipe when folding tail");
+      // The header mask itself is redundant with EVL and is dropped; any other
+      // mask is kept unchanged.
+      VPValue *NewMask = &Mask == OrigMask ? nullptr : OrigMask;
+      if (auto *L = dyn_cast<VPWidenLoadRecipe>(MemR)) {
+        auto *N = new VPWidenVPLoadRecipe(L, VPEVL, NewMask);
+        N->insertBefore(L);
+        L->replaceAllUsesWith(N);
+        L->eraseFromParent();
+      } else if (auto *S = dyn_cast<VPWidenStoreRecipe>(MemR)) {
+        auto *N = new VPWidenVPStoreRecipe(S, VPEVL, NewMask);
+        N->insertBefore(S);
+        S->eraseFromParent();
+      } else {
+        llvm_unreachable("unsupported recipe");
+      }
+    }
+  });
   // Replace all uses of VPCanonicalIVPHIRecipe by
   // VPEVLBasedIVPHIRecipe except for the canonical IV increment.
   CanonicalIVPHI->replaceAllUsesWith(EVLPhi);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 0bbc7ffb4a2fe0..ae9bf2b2af72df 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -358,6 +358,8 @@ class VPDef {
     VPWidenGEPSC,
     VPWidenLoadSC,
     VPWidenStoreSC,
+    VPWidenVPLoadSC,
+    VPWidenVPStoreSC,
     VPWidenSC,
     VPWidenSelectSC,
     VPBlendSC,
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-gather-scatter.ll b/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-gather-scatter.ll
index 835ff375688173..ae01bdd3711060 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-gather-scatter.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-gather-scatter.ll
@@ -26,7 +26,6 @@ define void @gather_scatter(ptr noalias %in, ptr noalias %out, ptr noalias %inde
 ; IF-EVL-NEXT:    [[N_RND_UP:%.*]] = add i64 [[N]], [[TMP8]]
 ; IF-EVL-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[N_RND_UP]], [[TMP5]]
 ; IF-EVL-NEXT:    [[N_VEC:%.*]] = sub i64 [[N_RND_UP]], [[N_MOD_VF]]
-; IF-EVL-NEXT:    [[TRIP_COUNT_MINUS_1:%.*]] = sub i64 [[N]], 1
 ; IF-EVL-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
 ; IF-EVL-NEXT:    [[TMP10:%.*]] = mul i64 [[TMP9]], 2
 ; IF-EVL-NEXT:    [[TMP11:%.*]] = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
@@ -36,9 +35,7 @@ define void @gather_scatter(ptr noalias %in, ptr noalias %out, ptr noalias %inde
 ; IF-EVL-NEXT:    [[TMP14:%.*]] = call i64 @llvm.vscale.i64()
 ; IF-EVL-NEXT:    [[TMP15:%.*]] = mul i64 [[TMP14]], 2
 ; IF-EVL-NEXT:    [[TMP16:%.*]] = mul i64 1, [[TMP15]]
-; IF-EVL-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP16]], i64 0
-; IF-EVL-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
-; IF-EVL-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TRIP_COUNT_MINUS_1]], i64 0
+; IF-EVL-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP16]], i64 0
 ; IF-EVL-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[BROADCAST_SPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
 ; IF-EVL-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; IF-EVL:       vector.body:
@@ -47,17 +44,16 @@ define void @gather_scatter(ptr noalias %in, ptr noalias %out, ptr noalias %inde
 ; IF-EVL-NEXT:    [[VEC_IND:%.*]] = phi <vscale x 2 x i64> [ [[INDUCTION]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; IF-EVL-NEXT:    [[TMP17:%.*]] = sub i64 [[N]], [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[TMP18:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[TMP17]], i32 2, i1 true)
-; IF-EVL-NEXT:    [[TMP19:%.*]] = icmp ule <vscale x 2 x i64> [[VEC_IND]], [[BROADCAST_SPLAT]]
 ; IF-EVL-NEXT:    [[TMP20:%.*]] = getelementptr inbounds i32, ptr [[INDEX:%.*]], <vscale x 2 x i64> [[VEC_IND]]
-; IF-EVL-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 2 x i64> @llvm.vp.gather.nxv2i64.nxv2p0(<vscale x 2 x ptr> align 8 [[TMP20]], <vscale x 2 x i1> [[TMP19]], i32 [[TMP18]])
+; IF-EVL-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 2 x i64> @llvm.vp.gather.nxv2i64.nxv2p0(<vscale x 2 x ptr> align 8 [[TMP20]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), i32 [[TMP18]])
 ; IF-EVL-NEXT:    [[TMP21:%.*]] = getelementptr inbounds float, ptr [[IN:%.*]], <vscale x 2 x i64> [[WIDE_MASKED_GATHER]]
-; IF-EVL-NEXT:    [[WIDE_MASKED_GATHER2:%.*]] = call <vscale x 2 x float> @llvm.vp.gather.nxv2f32.nxv2p0(<vscale x 2 x ptr> align 4 [[TMP21]], <vscale x 2 x i1> [[TMP19]], i32 [[TMP18]])
+; IF-EVL-NEXT:    [[WIDE_MASKED_GATHER2:%.*]] = call <vscale x 2 x float> @llvm.vp.gather.nxv2f32.nxv2p0(<vscale x 2 x ptr> align 4 [[TMP21]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), i32 [[TMP18]])
 ; IF-EVL-NEXT:    [[TMP22:%.*]] = getelementptr inbounds float, ptr [[OUT:%.*]], <vscale x 2 x i64> [[WIDE_MASKED_GATHER]]
-; IF-EVL-NEXT:    call void @llvm.vp.scatter.nxv2f32.nxv2p0(<vscale x 2 x float> [[WIDE_MASKED_GATHER2]], <vscale x 2 x ptr> align 4 [[TMP22]], <vscale x 2 x i1> [[TMP19]], i32 [[TMP18]])
+; IF-EVL-NEXT:    call void @llvm.vp.scatter.nxv2f32.nxv2p0(<vscale x 2 x float> [[WIDE_MASKED_GATHER2]], <vscale x 2 x ptr> align 4 [[TMP22]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), i32 [[TMP18]])
 ; IF-EVL-NEXT:    [[TMP23:%.*]] = zext i32 [[TMP18]] to i64
 ; IF-EVL-NEXT:    [[INDEX_EVL_NEXT]] = add i64 [[TMP23]], [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX1]], [[TMP10]]
-; IF-EVL-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 2 x i64> [[VEC_IND]], [[DOTSPLAT]]
+; IF-EVL-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 2 x i64> [[VEC_IND]], [[BROADCAST_SPLAT]]
 ; IF-EVL-NEXT:    [[TMP24:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; IF-EVL-NEXT:    br i1 [[TMP24]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; IF-EVL:       middle.block:
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics.ll b/llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics.ll
index 72b881bd44c768..8caa9368bfde18 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/vplan-vp-intrinsics.ll
@@ -27,14 +27,14 @@ define void @foo(ptr noalias %a, ptr noalias %b, ptr noalias %c, i64 %N) {
 ; IF-EVL-NEXT:    vp<[[ST:%[0-9]+]]> = SCALAR-STEPS vp<[[EVL_PHI]]>, ir<1>
 ; IF-EVL-NEXT:    CLONE ir<[[GEP1:%.+]]> = getelementptr inbounds ir<%b>, vp<[[ST]]>
 ; IF-EVL-NEXT:    vp<[[PTR1:%[0-9]+]]> = vector-pointer ir<[[GEP1]]>
-; IF-EVL-NEXT:    WIDEN ir<[[LD1:%.+]]> = load vp<[[PTR1]]>, ir<true>
+; IF-EVL-NEXT:    WIDEN ir<[[LD1:%.+]]> = vp.load vp<[[PTR1]]>, vp<[[EVL]]>
 ; IF-EVL-NEXT:    CLONE ir<[[GEP2:%.+]]> = getelementptr inbounds ir<%c>, vp<[[ST]]>
 ; IF-EVL-NEXT:    vp<[[PTR2:%[0-9]+]]> = vector-pointer ir<[[GEP2]]>
-; IF-EVL-NEXT:    WIDEN ir<[[LD2:%.+]]> = load vp<[[PTR2]]>, ir<true>
+; IF-EVL-NEXT:    WIDEN ir<[[LD2:%.+]]> = vp.load vp<[[PTR2]]>, vp<[[EVL]]>
 ; IF-EVL-NEXT:    WIDEN ir<[[ADD:%.+]]> = add nsw ir<[[LD2]]>, ir<[[LD1]]>
 ; IF-EVL-NEXT:    CLONE ir<[[GEP3:%.+]]> = getelementptr inbounds ir<%a>, vp<[[ST]]>
 ; IF-EVL-NEXT:    vp<[[PTR3:%[0-9]+]]> = vector-pointer ir<[[GEP3]]>
-; IF-EVL-NEXT:    WIDEN store vp<[[PTR3]]>, ir<[[ADD]]>, ir<true>
+; IF-EVL-NEXT:    WIDEN vp.store vp<[[PTR3]]>, ir<[[ADD]]>, vp<[[EVL]]>
 ; IF-EVL-NEXT:    SCALAR-CAST vp<[[CAST:%[0-9]+]]> = zext vp<[[EVL]]> to i64
 ; IF-EVL-NEXT:    EMIT vp<[[IV_NEXT]]> = add vp<[[CAST]]>, vp<[[EVL_PHI]]>
 ; IF-EVL-NEXT:    EMIT vp<[[IV_NEXT_EXIT:%[0-9]+]]> = add vp<[[IV]]>, vp<[[VFUF]]>

>From cb979499ed96b116b5a6ca8f367c698468298773 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 18 Apr 2024 16:02:56 +0100
Subject: [PATCH 2/2] !fixup address latest comments, thanks!

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 44 ++++++---------
 llvm/lib/Transforms/Vectorize/VPlan.h         | 35 ++++++------
 .../Transforms/Vectorize/VPlanAnalysis.cpp    |  2 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 20 +++----
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 55 +++++++++++--------
 llvm/lib/Transforms/Vectorize/VPlanValue.h    |  4 +-
 6 files changed, 80 insertions(+), 80 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 6662682825849e..38f0ebd04a000d 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9338,32 +9338,27 @@ void VPWidenLoadRecipe::execute(VPTransformState &State) {
         Mask = Builder.CreateVectorReverse(Mask, "reverse");
     }
 
+    Value *Addr = State.get(getAddr(), Part, /*IsScalar*/ !CreateGather);
     if (CreateGather) {
-      Value *VectorGep = State.get(getAddr(), Part);
-      NewLI = Builder.CreateMaskedGather(DataTy, VectorGep, Alignment, Mask,
-                                         nullptr, "wide.masked.gather");
+      NewLI = Builder.CreateMaskedGather(DataTy, Addr, Alignment, Mask, nullptr,
+                                         "wide.masked.gather");
       State.addMetadata(NewLI, LI);
+    } else if (Mask) {
+      NewLI = Builder.CreateMaskedLoad(DataTy, Addr, Alignment, Mask,
+                                       PoisonValue::get(DataTy),
+                                       "wide.masked.load");
     } else {
-      auto *VecPtr = State.get(getAddr(), Part, /*IsScalar*/ true);
-      if (Mask)
-        NewLI = Builder.CreateMaskedLoad(DataTy, VecPtr, Alignment, Mask,
-                                         PoisonValue::get(DataTy),
-                                         "wide.masked.load");
-      else
-        NewLI =
-            Builder.CreateAlignedLoad(DataTy, VecPtr, Alignment, "wide.load");
-
-      // Add metadata to the load, but setVectorValue to the reverse shuffle.
-      State.addMetadata(NewLI, LI);
-      if (Reverse)
-        NewLI = Builder.CreateVectorReverse(NewLI, "reverse");
+      NewLI = Builder.CreateAlignedLoad(DataTy, Addr, Alignment, "wide.load");
     }
-
+    // Add metadata to the load, not to the reverse shuffle stored below.
+    State.addMetadata(NewLI, LI);
+    if (Reverse)
+      NewLI = Builder.CreateVectorReverse(NewLI, "reverse");
     State.set(this, NewLI, Part);
   }
 }
 
-void VPWidenVPLoadRecipe::execute(VPTransformState &State) {
+void VPWidenEVLLoadRecipe::execute(VPTransformState &State) {
   assert(State.UF == 1 && "Expected only UF == 1 when vectorizing with "
                           "explicit vector length.");
   // FIXME: Support reverse loading after vp_reverse is added.
@@ -9430,23 +9425,20 @@ void VPWidenStoreRecipe::execute(VPTransformState &State) {
       // We don't want to update the value in the map as it might be used in
       // another expression. So don't call resetVectorValue(StoredVal).
     }
-    // TODO: split this into several classes for better design.
+    Value *Addr = State.get(getAddr(), Part, /*IsScalar*/ !CreateScatter);
     if (CreateScatter) {
-      Value *VectorGep = State.get(getAddr(), Part);
-      NewSI =
-          Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment, Mask);
+      NewSI = Builder.CreateMaskedScatter(StoredVal, Addr, Alignment, Mask);
     } else {
-      auto *VecPtr = State.get(getAddr(), Part, /*IsScalar*/ true);
       if (Mask)
-        NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr, Alignment, Mask);
+        NewSI = Builder.CreateMaskedStore(StoredVal, Addr, Alignment, Mask);
       else
-        NewSI = Builder.CreateAlignedStore(StoredVal, VecPtr, Alignment);
+        NewSI = Builder.CreateAlignedStore(StoredVal, Addr, Alignment);
     }
     State.addMetadata(NewSI, SI);
   }
 }
 
-void VPWidenVPStoreRecipe::execute(VPTransformState &State) {
+void VPWidenEVLStoreRecipe::execute(VPTransformState &State) {
   assert(State.UF == 1 && "Expected only UF == 1 when vectorizing with "
                           "explicit vector length.");
   // FIXME: Support reverse loading after vp_reverse is added.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index a0ecac0fae76a8..4d0a3ac77dc137 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -868,8 +868,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPBranchOnMaskSC:
     case VPRecipeBase::VPWidenLoadSC:
     case VPRecipeBase::VPWidenStoreSC:
-    case VPRecipeBase::VPWidenVPLoadSC:
-    case VPRecipeBase::VPWidenVPStoreSC:
+    case VPRecipeBase::VPWidenEVLLoadSC:
+    case VPRecipeBase::VPWidenEVLStoreSC:
       // TODO: Widened stores don't define a value, but widened loads do. Split
       // the recipes to be able to make widened loads VPSingleDefRecipes.
       return false;
@@ -2318,8 +2318,8 @@ class VPWidenMemoryRecipe : public VPRecipeBase {
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPWidenLoadSC ||
            R->getVPDefID() == VPRecipeBase::VPWidenStoreSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenVPLoadSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenVPStoreSC;
+           R->getVPDefID() == VPRecipeBase::VPWidenEVLLoadSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenEVLStoreSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2387,9 +2387,8 @@ struct VPWidenLoadRecipe final : public VPWidenMemoryRecipe, public VPValue {
   bool onlyFirstLaneUsed(const VPValue *Op) const override {
     assert(is_contained(operands(), Op) &&
            "Op must be an operand of the recipe");
-    // Widened, consecutive memory operations only demand the first lane of
-    // their address, unless the same operand is also stored. That latter can
-    // happen with opaque pointers.
+    // Widened, consecutive load operations only demand the first lane of
+    // their address.
     return Op == getAddr() && isConsecutive();
   }
 };
@@ -2397,21 +2396,21 @@ struct VPWidenLoadRecipe final : public VPWidenMemoryRecipe, public VPValue {
 /// A recipe for widening load operations with vector-predication intrinsics,
 /// using the address to load from, the explicit vector length and an optional
 /// mask.
-struct VPWidenVPLoadRecipe final : public VPWidenMemoryRecipe, public VPValue {
-  VPWidenVPLoadRecipe(VPWidenLoadRecipe *L, VPValue *EVL, VPValue *Mask)
+struct VPWidenEVLLoadRecipe final : public VPWidenMemoryRecipe, public VPValue {
+  VPWidenEVLLoadRecipe(VPWidenLoadRecipe *L, VPValue *EVL, VPValue *Mask)
       : VPWidenMemoryRecipe(
-            VPDef::VPWidenVPLoadSC, *cast<LoadInst>(&L->getIngredient()),
+            VPDef::VPWidenEVLLoadSC, *cast<LoadInst>(&L->getIngredient()),
             {L->getAddr(), EVL}, L->isConsecutive(), false, L->getDebugLoc()),
         VPValue(this, &getIngredient()) {
     setMask(Mask);
   }
 
-  VP_CLASSOF_IMPL(VPDef::VPWidenVPLoadSC)
+  VP_CLASSOF_IMPL(VPDef::VPWidenEVLLoadSC)
 
   /// Return the EVL operand.
   VPValue *getEVL() const { return getOperand(1); }
 
-  /// Generate the wide load/store.
+  /// Generate the wide load or gather.
   void execute(VPTransformState &State) override;
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2471,18 +2470,18 @@ struct VPWidenStoreRecipe final : public VPWidenMemoryRecipe {
 };
 
 /// A recipe for widening store operations with vector-predication intrinsics,
-/// using the value to store, the address to store to , the explicit vector
+/// using the value to store, the address to store to, the explicit vector
 /// length and an optional mask.
-struct VPWidenVPStoreRecipe final : public VPWidenMemoryRecipe {
-  VPWidenVPStoreRecipe(VPWidenStoreRecipe *S, VPValue *EVL, VPValue *Mask)
-      : VPWidenMemoryRecipe(VPDef::VPWidenVPStoreSC,
+struct VPWidenEVLStoreRecipe final : public VPWidenMemoryRecipe {
+  VPWidenEVLStoreRecipe(VPWidenStoreRecipe *S, VPValue *EVL, VPValue *Mask)
+      : VPWidenMemoryRecipe(VPDef::VPWidenEVLStoreSC,
                             *cast<StoreInst>(&S->getIngredient()),
                             {S->getAddr(), S->getStoredValue(), EVL},
                             S->isConsecutive(), false, S->getDebugLoc()) {
     setMask(Mask);
   }
 
-  VP_CLASSOF_IMPL(VPDef::VPWidenVPStoreSC)
+  VP_CLASSOF_IMPL(VPDef::VPWidenEVLStoreSC)
 
   /// Return the address accessed by this recipe.
   VPValue *getStoredValue() const { return getOperand(1); }
@@ -2490,7 +2489,7 @@ struct VPWidenVPStoreRecipe final : public VPWidenMemoryRecipe {
   /// Return the EVL operand.
   VPValue *getEVL() const { return getOperand(2); }
 
-  /// Generate the wide load/store.
+  /// Generate the wide store or scatter.
   void execute(VPTransformState &State) override;
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index d820a65299cf4f..1beed001ce9651 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -109,7 +109,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) {
 }
 
 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R) {
-  assert((isa<VPWidenLoadRecipe>(R) || isa<VPWidenVPLoadRecipe>(R)) &&
+  assert((isa<VPWidenLoadRecipe>(R) || isa<VPWidenEVLLoadRecipe>(R)) &&
          "Store recipes should not define any values");
   return cast<LoadInst>(&R->getIngredient())->getType();
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 6f3a818b74a058..9f3a1b4d43ff1f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -48,7 +48,7 @@ bool VPRecipeBase::mayWriteToMemory() const {
   case VPInterleaveSC:
     return cast<VPInterleaveRecipe>(this)->getNumStoreOperands() > 0;
   case VPWidenStoreSC:
-  case VPWidenVPStoreSC:
+  case VPWidenEVLStoreSC:
     return true;
   case VPReplicateSC:
   case VPWidenCallSC:
@@ -66,7 +66,7 @@ bool VPRecipeBase::mayWriteToMemory() const {
   case VPWidenIntOrFpInductionSC:
   case VPWidenLoadSC:
   case VPWidenPHISC:
-  case VPWidenVPLoadSC:
+  case VPWidenEVLLoadSC:
   case VPWidenSC:
   case VPWidenSelectSC: {
     const Instruction *I =
@@ -84,7 +84,7 @@ bool VPRecipeBase::mayWriteToMemory() const {
 bool VPRecipeBase::mayReadFromMemory() const {
   switch (getVPDefID()) {
   case VPWidenLoadSC:
-  case VPWidenVPLoadSC:
+  case VPWidenEVLLoadSC:
     return true;
   case VPReplicateSC:
   case VPWidenCallSC:
@@ -94,7 +94,7 @@ bool VPRecipeBase::mayReadFromMemory() const {
   case VPPredInstPHISC:
   case VPScalarIVStepsSC:
   case VPWidenStoreSC:
-  case VPWidenVPStoreSC:
+  case VPWidenEVLStoreSC:
     return false;
   case VPBlendSC:
   case VPReductionSC:
@@ -161,8 +161,8 @@ bool VPRecipeBase::mayHaveSideEffects() const {
     return mayWriteToMemory();
   case VPWidenLoadSC:
   case VPWidenStoreSC:
-  case VPWidenVPLoadSC:
-  case VPWidenVPStoreSC:
+  case VPWidenEVLLoadSC:
+  case VPWidenEVLStoreSC:
     assert(
         cast<VPWidenMemoryRecipe>(this)->getIngredient().mayHaveSideEffects() ==
             mayWriteToMemory() &&
@@ -1782,8 +1782,8 @@ void VPWidenLoadRecipe::print(raw_ostream &O, const Twine &Indent,
   printOperands(O, SlotTracker);
 }
 
-void VPWidenVPLoadRecipe::print(raw_ostream &O, const Twine &Indent,
-                                VPSlotTracker &SlotTracker) const {
+void VPWidenEVLLoadRecipe::print(raw_ostream &O, const Twine &Indent,
+                                 VPSlotTracker &SlotTracker) const {
   O << Indent << "WIDEN ";
   printAsOperand(O, SlotTracker);
   O << " = vp.load ";
@@ -1796,8 +1796,8 @@ void VPWidenStoreRecipe::print(raw_ostream &O, const Twine &Indent,
   printOperands(O, SlotTracker);
 }
 
-void VPWidenVPStoreRecipe::print(raw_ostream &O, const Twine &Indent,
-                                 VPSlotTracker &SlotTracker) const {
+void VPWidenEVLStoreRecipe::print(raw_ostream &O, const Twine &Indent,
+                                  VPSlotTracker &SlotTracker) const {
   O << Indent << "WIDEN vp.store ";
   printOperands(O, SlotTracker);
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 7b6634499c6bf8..00328a47fbfd9d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1203,10 +1203,9 @@ static VPActiveLaneMaskPHIRecipe *addVPLaneMaskPhiAndUpdateExitBranch(
   return LaneMaskPhi;
 }
 
-/// Apply \p Fn to all VPInstructions matching the header mask (ICMP_ULE,
-/// WideCanonicalIV, backedge-taken-count) pattern
-static void forAllHeaderPredicates(VPlan &Plan,
-                                   function_ref<void(VPInstruction &)> Fn) {
+/// Collect all VPValues representing a header mask, i.e. compares matching the
+/// (ICMP_ULE, WideCanonicalIV, backedge-taken-count) pattern.
+static SmallVector<VPValue *> collectAllHeaderMasks(VPlan &Plan) {
   SmallVector<VPValue *> WideCanonicalIVs;
   auto *FoundWidenCanonicalIVUser =
       find_if(Plan.getCanonicalIV()->users(),
@@ -1214,6 +1213,12 @@ static void forAllHeaderPredicates(VPlan &Plan,
   if (FoundWidenCanonicalIVUser != Plan.getCanonicalIV()->users().end()) {
     auto *WideCanonicalIV =
         cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser);
+    assert(all_of(Plan.getCanonicalIV()->users(),
+                  [WideCanonicalIV](VPUser *U) {
+                    return !isa<VPWidenCanonicalIVRecipe>(U) ||
+                           U == WideCanonicalIV;
+                  }) &&
+           "Must have a single WideCanonicalIV");
     WideCanonicalIVs.push_back(WideCanonicalIV);
   }
 
@@ -1226,25 +1231,26 @@ static void forAllHeaderPredicates(VPlan &Plan,
       WideCanonicalIVs.push_back(WidenOriginalIV);
   }
 
-  // Walk users of wide canonical IVs and replace all compares of the form
-  // (ICMP_ULE, WideCanonicalIV, backedge-taken-count) with
-  // the given idiom VPValue.
+  // Walk users of wide canonical IVs and collect all compares of the form
+  // (ICMP_ULE, WideCanonicalIV, backedge-taken-count).
+  SmallVector<VPValue *> HeaderMasks;
   VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
   for (auto *Wide : WideCanonicalIVs) {
     for (VPUser *U : SmallVector<VPUser *>(Wide->users())) {
-      auto *CompareToReplace = dyn_cast<VPInstruction>(U);
-      if (!CompareToReplace ||
-          CompareToReplace->getOpcode() != Instruction::ICmp ||
-          CompareToReplace->getPredicate() != CmpInst::ICMP_ULE ||
-          CompareToReplace->getOperand(1) != BTC)
+      // TODO: Introduce explicit recipe for header-mask instead of searching
+      // for the header-mask pattern manually.
+      auto *HeaderMask = dyn_cast<VPInstruction>(U);
+      if (!HeaderMask || HeaderMask->getOpcode() != Instruction::ICmp ||
+          HeaderMask->getPredicate() != CmpInst::ICMP_ULE ||
+          HeaderMask->getOperand(1) != BTC)
         continue;
 
-      assert(CompareToReplace->getOperand(0) == Wide &&
+      assert(HeaderMask->getOperand(0) == Wide &&
              "WidenCanonicalIV must be the first operand of the compare");
-      Fn(*CompareToReplace);
-      recursivelyDeleteDeadRecipes(CompareToReplace);
+      HeaderMasks.push_back(HeaderMask);
     }
   }
+  return HeaderMasks;
 }
 
 void VPlanTransforms::addActiveLaneMask(
@@ -1276,8 +1282,10 @@ void VPlanTransforms::addActiveLaneMask(
   // Walk users of WideCanonicalIV and replace all compares of the form
   // (ICMP_ULE, WideCanonicalIV, backedge-taken-count) with an
   // active-lane-mask.
-  forAllHeaderPredicates(
-      Plan, [LaneMask](VPInstruction &I) { I.replaceAllUsesWith(LaneMask); });
+  for (VPValue *HeaderMask : collectAllHeaderMasks(Plan)) {
+    HeaderMask->replaceAllUsesWith(LaneMask);
+    recursivelyDeleteDeadRecipes(HeaderMask);
+  }
 }
 
 /// Add a VPEVLBasedIVPHIRecipe and related recipes to \p Plan and
@@ -1332,8 +1340,8 @@ void VPlanTransforms::addExplicitVectorLength(VPlan &Plan) {
   NextEVLIV->insertBefore(CanonicalIVIncrement);
   EVLPhi->addOperand(NextEVLIV);
 
-  forAllHeaderPredicates(Plan, [VPEVL](VPInstruction &Mask) {
-    for (VPUser *U : collectUsersRecursively(&Mask)) {
+  for (VPValue *HeaderMask : collectAllHeaderMasks(Plan)) {
+    for (VPUser *U : collectUsersRecursively(HeaderMask)) {
       auto *MemR = dyn_cast<VPWidenMemoryRecipe>(U);
       if (!MemR)
         continue;
@@ -1341,21 +1349,22 @@ void VPlanTransforms::addExplicitVectorLength(VPlan &Plan) {
              "Reversed memory operations not supported yet.");
       VPValue *OrigMask = MemR->getMask();
       assert(OrigMask && "Unmasked widen memory recipe when folding tail");
-      VPValue *NewMask = &Mask == OrigMask ? nullptr : OrigMask;
+      VPValue *NewMask = HeaderMask == OrigMask ? nullptr : OrigMask;
       if (auto *L = dyn_cast<VPWidenLoadRecipe>(MemR)) {
-        auto *N = new VPWidenVPLoadRecipe(L, VPEVL, NewMask);
+        auto *N = new VPWidenEVLLoadRecipe(L, VPEVL, NewMask);
         N->insertBefore(L);
         L->replaceAllUsesWith(N);
         L->eraseFromParent();
       } else if (auto *S = dyn_cast<VPWidenStoreRecipe>(MemR)) {
-        auto *N = new VPWidenVPStoreRecipe(S, VPEVL, NewMask);
+        auto *N = new VPWidenEVLStoreRecipe(S, VPEVL, NewMask);
         N->insertBefore(S);
         S->eraseFromParent();
       } else {
         llvm_unreachable("unsupported recipe");
       }
     }
-  });
+    recursivelyDeleteDeadRecipes(HeaderMask);
+  }
   // Replace all uses of VPCanonicalIVPHIRecipe by
   // VPEVLBasedIVPHIRecipe except for the canonical IV increment.
   CanonicalIVPHI->replaceAllUsesWith(EVLPhi);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index ae9bf2b2af72df..b2194f0fee5f85 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -358,8 +358,8 @@ class VPDef {
     VPWidenGEPSC,
     VPWidenLoadSC,
     VPWidenStoreSC,
-    VPWidenVPLoadSC,
-    VPWidenVPStoreSC,
+    VPWidenEVLLoadSC,
+    VPWidenEVLStoreSC,
     VPWidenSC,
     VPWidenSelectSC,
     VPBlendSC,



More information about the llvm-commits mailing list