[llvm] [SLP][NFC]Unify ScalarToTreeEntries and MultiNodeScalars, NFC (PR #124914)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 29 04:20:35 PST 2025


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/124914

>From 3219906b39a4fe5969c1933ce6c26110de153fb7 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Wed, 29 Jan 2025 12:15:26 +0000
Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20in?=
 =?UTF-8?q?itial=20version?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 516 +++++++++---------
 1 file changed, 244 insertions(+), 272 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2532edc5d86990..790c4dba0dc36b 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1476,8 +1476,7 @@ class BoUpSLP {
   /// Clear the internal data structures that are created by 'buildTree'.
   void deleteTree() {
     VectorizableTree.clear();
-    ScalarToTreeEntry.clear();
-    MultiNodeScalars.clear();
+    ScalarToTreeEntries.clear();
     MustGather.clear();
     NonScheduledFirst.clear();
     EntryToLastInstruction.clear();
@@ -1760,7 +1759,7 @@ class BoUpSLP {
 
             auto AllUsersVectorized = [U1, U2, this](Value *V) {
               return llvm::all_of(V->users(), [U1, U2, this](Value *U) {
-                return U == U1 || U == U2 || R.getTreeEntry(U) != nullptr;
+                return U == U1 || U == U2 || R.isVectorized(U);
               });
             };
             return AllUsersVectorized(V1) && AllUsersVectorized(V2);
@@ -1776,9 +1775,13 @@ class BoUpSLP {
       }
 
       auto CheckSameEntryOrFail = [&]() {
-        if (const TreeEntry *TE1 = R.getTreeEntry(V1);
-            TE1 && TE1 == R.getTreeEntry(V2))
-          return LookAheadHeuristics::ScoreSplatLoads;
+        if (ArrayRef<TreeEntry *> TEs1 = R.getTreeEntries(V1); !TEs1.empty()) {
+          SmallPtrSet<TreeEntry *, 4> Set(TEs1.begin(), TEs1.end());
+          if (ArrayRef<TreeEntry *> TEs2 = R.getTreeEntries(V2);
+              !TEs2.empty() &&
+              any_of(TEs2, [&](TreeEntry *E) { return Set.contains(E); }))
+            return LookAheadHeuristics::ScoreSplatLoads;
+        }
         return LookAheadHeuristics::ScoreFail;
       };
 
@@ -2851,13 +2854,7 @@ class BoUpSLP {
         continue;
       auto *I = cast<Instruction>(V);
       salvageDebugInfo(*I);
-      SmallVector<const TreeEntry *> Entries;
-      if (const TreeEntry *Entry = getTreeEntry(I)) {
-        Entries.push_back(Entry);
-        auto It = MultiNodeScalars.find(I);
-        if (It != MultiNodeScalars.end())
-          Entries.append(It->second.begin(), It->second.end());
-      }
+      ArrayRef<TreeEntry *> Entries = getTreeEntries(I);
       for (Use &U : I->operands()) {
         if (auto *OpI = dyn_cast_if_present<Instruction>(U.get());
             OpI && !DeletedInstructions.contains(OpI) && OpI->hasOneUser() &&
@@ -2961,7 +2958,11 @@ class BoUpSLP {
   }
 
   /// Check if the value is vectorized in the tree.
-  bool isVectorized(Value *V) const { return getTreeEntry(V); }
+  bool isVectorized(Value *V) const {
+    assert(V && "V cannot be nullptr.");
+    return ScalarToTreeEntries.contains(V);
+  }
+
 
   ~BoUpSLP();
 
@@ -2999,16 +3000,10 @@ class BoUpSLP {
     ArrayRef<Value *> VL = UserTE->getOperand(OpIdx);
     TreeEntry *TE = nullptr;
     const auto *It = find_if(VL, [&](Value *V) {
-      TE = getTreeEntry(V);
-      if (TE && is_contained(TE->UserTreeIndices, EdgeInfo(UserTE, OpIdx)))
-        return true;
-      auto It = MultiNodeScalars.find(V);
-      if (It != MultiNodeScalars.end()) {
-        for (TreeEntry *E : It->second) {
-          if (is_contained(E->UserTreeIndices, EdgeInfo(UserTE, OpIdx))) {
-            TE = E;
-            return true;
-          }
+      for (TreeEntry *E : getTreeEntries(V)) {
+        if (is_contained(E->UserTreeIndices, EdgeInfo(UserTE, OpIdx))) {
+          TE = E;
+          return true;
         }
       }
       return false;
@@ -3659,18 +3654,24 @@ class BoUpSLP {
       Last->ReorderIndices.append(ReorderIndices.begin(), ReorderIndices.end());
     }
     if (!Last->isGather()) {
+      SmallPtrSet<Value *, 4> Processed;
       for (Value *V : VL) {
         if (isa<PoisonValue>(V))
           continue;
-        const TreeEntry *TE = getTreeEntry(V);
-        assert((!TE || TE == Last || doesNotNeedToBeScheduled(V)) &&
-               "Scalar already in tree!");
-        if (TE) {
-          if (TE != Last)
-            MultiNodeScalars.try_emplace(V).first->getSecond().push_back(Last);
-          continue;
+        auto It = ScalarToTreeEntries.find(V);
+        assert(
+            (It == ScalarToTreeEntries.end() ||
+             (It->getSecond().size() == 1 && It->getSecond().front() == Last) ||
+             doesNotNeedToBeScheduled(V)) &&
+            "Scalar already in tree!");
+        if (It == ScalarToTreeEntries.end()) {
+          ScalarToTreeEntries.try_emplace(V).first->getSecond().push_back(Last);
+          (void)Processed.insert(V);
+        } else if (Processed.insert(V).second) {
+          assert(!is_contained(It->getSecond(), Last) &&
+                 "Value already associated with the node.");
+          It->getSecond().push_back(Last);
         }
-        ScalarToTreeEntry[V] = Last;
       }
       // Update the scheduler bundle to point to this TreeEntry.
       ScheduleData *BundleMember = *Bundle;
@@ -3725,14 +3726,23 @@ class BoUpSLP {
   }
 #endif
 
-  TreeEntry *getTreeEntry(Value *V) {
+  /// Returns the list of vector entries associated with the value \p V.
+  ArrayRef<TreeEntry *> getTreeEntries(Value *V) const {
     assert(V && "V cannot be nullptr.");
-    return ScalarToTreeEntry.lookup(V);
+    auto It = ScalarToTreeEntries.find(V);
+    if (It == ScalarToTreeEntries.end())
+      return {};
+    return It->getSecond();
   }
 
-  const TreeEntry *getTreeEntry(Value *V) const {
+  /// Returns the first vector node for value \p V that matches values \p VL.
+  TreeEntry *getSameValuesTreeEntry(Value *V, ArrayRef<Value *> VL,
+                                    bool SameVF = false) const {
     assert(V && "V cannot be nullptr.");
-    return ScalarToTreeEntry.lookup(V);
+    for (TreeEntry *TE : ScalarToTreeEntries.lookup(V))
+      if ((!SameVF || TE->getVectorFactor() == VL.size()) && TE->isSame(VL))
+        return TE;
+    return nullptr;
   }
 
   /// Check that the operand node of alternate node does not generate
@@ -3752,12 +3762,8 @@ class BoUpSLP {
                                OrdersType &CurrentOrder,
                                SmallVectorImpl<Value *> &PointerOps);
 
-  /// Maps a specific scalar to its tree entry.
-  SmallDenseMap<Value *, TreeEntry *> ScalarToTreeEntry;
-
-  /// List of scalars, used in several vectorize nodes, and the list of the
-  /// nodes.
-  SmallDenseMap<Value *, SmallVector<TreeEntry *>> MultiNodeScalars;
+  /// Maps a specific scalar to its tree entry(ies).
+  SmallDenseMap<Value *, SmallVector<TreeEntry *>> ScalarToTreeEntries;
 
   /// Maps a value to the proposed vectorizable size.
   SmallDenseMap<Value *, unsigned> InstrElementSize;
@@ -3798,16 +3804,19 @@ class BoUpSLP {
 
   /// This POD struct describes one external user in the vectorized tree.
   struct ExternalUser {
-    ExternalUser(Value *S, llvm::User *U, int L)
-        : Scalar(S), User(U), Lane(L) {}
+    ExternalUser(Value *S, llvm::User *U, const TreeEntry &E, int L)
+        : Scalar(S), User(U), E(E), Lane(L) {}
+
+    /// Which scalar in our function.
+    Value *Scalar = nullptr;
 
-    // Which scalar in our function.
-    Value *Scalar;
+    /// Which user uses the scalar.
+    llvm::User *User = nullptr;
 
-    // Which user that uses the scalar.
-    llvm::User *User;
+    /// Vector node that the value is part of.
+    const TreeEntry &E;
 
-    // Which lane does the scalar belong to.
+    /// Which lane does the scalar belong to.
     int Lane;
   };
   using UserList = SmallVector<ExternalUser, 16>;
@@ -5113,7 +5122,7 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
     auto IsAnyPointerUsedOutGraph =
         IsPossibleStrided && any_of(PointerOps, [&](Value *V) {
           return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
-                   return !getTreeEntry(U) && !MustGather.contains(U);
+                   return !isVectorized(U) && !MustGather.contains(U);
                  });
         });
     const unsigned AbsoluteDiff = std::abs(*Diff);
@@ -6572,7 +6581,7 @@ void BoUpSLP::buildExternalUses(
         LLVM_DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane "
                           << FoundLane << " from " << *Scalar << ".\n");
         ScalarToExtUses.try_emplace(Scalar, ExternalUses.size());
-        ExternalUses.emplace_back(Scalar, nullptr, FoundLane);
+        ExternalUses.emplace_back(Scalar, nullptr, *Entry, FoundLane);
         continue;
       }
       for (User *U : Scalar->users()) {
@@ -6587,16 +6596,24 @@ void BoUpSLP::buildExternalUses(
           continue;
 
         // Skip in-tree scalars that become vectors
-        if (TreeEntry *UseEntry = getTreeEntry(U)) {
+        if (ArrayRef<TreeEntry *> UseEntries = getTreeEntries(U);
+            !UseEntries.empty()) {
           // Some in-tree scalars will remain as scalar in vectorized
           // instructions. If that is the case, the one in FoundLane will
           // be used.
-          if (UseEntry->State == TreeEntry::ScatterVectorize ||
-              !doesInTreeUserNeedToExtract(
-                  Scalar, getRootEntryInstruction(*UseEntry), TLI, TTI)) {
+          if (any_of(UseEntries, [&](TreeEntry *UseEntry) {
+                return UseEntry->State == TreeEntry::ScatterVectorize ||
+                       !doesInTreeUserNeedToExtract(
+                           Scalar, getRootEntryInstruction(*UseEntry), TLI,
+                           TTI);
+              })) {
             LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
                               << ".\n");
-            assert(!UseEntry->isGather() && "Bad state");
+            assert(none_of(UseEntries,
+                           [](TreeEntry *UseEntry) {
+                             return UseEntry->isGather();
+                           }) &&
+                   "Bad state");
             continue;
           }
           U = nullptr;
@@ -6613,7 +6630,7 @@ void BoUpSLP::buildExternalUses(
                           << " from lane " << FoundLane << " from " << *Scalar
                           << ".\n");
         It = ScalarToExtUses.try_emplace(Scalar, ExternalUses.size()).first;
-        ExternalUses.emplace_back(Scalar, U, FoundLane);
+        ExternalUses.emplace_back(Scalar, U, *Entry, FoundLane);
         if (!U)
           break;
       }
@@ -6644,7 +6661,7 @@ BoUpSLP::collectUserStores(const BoUpSLP::TreeEntry *TE) const {
           !isValidElementType(SI->getValueOperand()->getType()))
         continue;
       // Skip entry if already
-      if (getTreeEntry(U))
+      if (isVectorized(U))
         continue;
 
       Value *Ptr =
@@ -7027,10 +7044,11 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
               for (User *U : LI->users()) {
                 if (auto *UI = dyn_cast<Instruction>(U); UI && isDeleted(UI))
                   continue;
-                if (const TreeEntry *UTE = getTreeEntry(U)) {
+                for (const TreeEntry *UTE : getTreeEntries(U)) {
                   for (int I : seq<int>(UTE->getNumOperands())) {
-                    if (all_of(UTE->getOperand(I),
-                               [LI](Value *V) { return V == LI; }))
+                    if (all_of(UTE->getOperand(I), [LI](Value *V) {
+                          return V == LI || isa<PoisonValue>(V);
+                        }))
                       // Found legal broadcast - do not vectorize.
                       return false;
                   }
@@ -7135,7 +7153,7 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
           int LastDist = LocalLoadsDists.front().second;
           bool AllowMaskedGather = IsMaskedGatherSupported(OriginalLoads);
           for (const std::pair<LoadInst *, int> &L : LocalLoadsDists) {
-            if (getTreeEntry(L.first))
+            if (isVectorized(L.first))
               continue;
             assert(LastDist >= L.second &&
                    "Expected first distance always not less than second");
@@ -7187,9 +7205,9 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
           for (auto [Slice, _] : Results) {
             LLVM_DEBUG(dbgs() << "SLP: Trying to vectorize gathered loads ("
                               << Slice.size() << ")\n");
-            if (any_of(Slice, [&](Value *V) { return getTreeEntry(V); })) {
+            if (any_of(Slice, [&](Value *V) { return isVectorized(V); })) {
               for (Value *L : Slice)
-                if (!getTreeEntry(L))
+                if (!isVectorized(L))
                   SortedNonVectorized.push_back(cast<LoadInst>(L));
               continue;
             }
@@ -7228,7 +7246,7 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
                         any_of(E->Scalars, [&, Slice = Slice](Value *V) {
                           if (isa<Constant>(V))
                             return false;
-                          if (getTreeEntry(V))
+                          if (isVectorized(V))
                             return true;
                           const auto &Nodes = ValueToGatherNodes.at(V);
                           return (Nodes.size() != 1 || !Nodes.contains(E)) &&
@@ -7315,7 +7333,7 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
               for (unsigned I = 0, E = Slice.size(); I < E; I += VF) {
                 ArrayRef<Value *> SubSlice =
                     Slice.slice(I, std::min(VF, E - I));
-                if (getTreeEntry(SubSlice.front()))
+                if (isVectorized(SubSlice.front()))
                   continue;
                 // Check if the subslice is to be-vectorized entry, which is not
                 // equal to entry.
@@ -7585,7 +7603,7 @@ bool BoUpSLP::areAltOperandsProfitable(const InstructionsState &S,
                    DenseMap<Value *, unsigned> Uniques;
                    for (Value *V : Op) {
                      if (isa<Constant, ExtractElementInst>(V) ||
-                         getTreeEntry(V) || (L && L->isLoopInvariant(V))) {
+                         isVectorized(V) || (L && L->isLoopInvariant(V))) {
                        if (isa<UndefValue>(V))
                          ++UndefCnt;
                        continue;
@@ -7603,7 +7621,7 @@ bool BoUpSLP::areAltOperandsProfitable(const InstructionsState &S,
                    return none_of(Uniques, [&](const auto &P) {
                      return P.first->hasNUsesOrMore(P.second + 1) &&
                             none_of(P.first->users(), [&](User *U) {
-                              return getTreeEntry(U) || Uniques.contains(U);
+                              return isVectorized(U) || Uniques.contains(U);
                             });
                    });
                  }) ||
@@ -8167,59 +8185,25 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
 
   // Check if this is a duplicate of another entry.
   if (S) {
-    if (TreeEntry *E = getTreeEntry(S.getMainOp())) {
-      LLVM_DEBUG(dbgs() << "SLP: \tChecking bundle: " << *S.getMainOp()
-                        << ".\n");
-      if (GatheredLoadsEntriesFirst.has_value() || !E->isSame(VL)) {
-        auto It = MultiNodeScalars.find(S.getMainOp());
-        if (It != MultiNodeScalars.end()) {
-          auto *TEIt = find_if(It->getSecond(),
-                               [&](TreeEntry *ME) { return ME->isSame(VL); });
-          if (TEIt != It->getSecond().end())
-            E = *TEIt;
-          else
-            E = nullptr;
-        } else {
-          E = nullptr;
-        }
-      }
-      if (!E) {
-        if (!doesNotNeedToBeScheduled(S.getMainOp())) {
-          LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n");
-          if (TryToFindDuplicates(S))
-            newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
-                         ReuseShuffleIndices);
-          return;
-        }
-        SmallPtrSet<const TreeEntry *, 4> Nodes;
-        Nodes.insert(getTreeEntry(S.getMainOp()));
-        for (const TreeEntry *E : MultiNodeScalars.lookup(S.getMainOp()))
-          Nodes.insert(E);
-        SmallPtrSet<Value *, 8> Values(VL.begin(), VL.end());
-        if (any_of(Nodes, [&](const TreeEntry *E) {
-              if (all_of(E->Scalars,
-                         [&](Value *V) { return Values.contains(V); }))
-                return true;
-              SmallPtrSet<Value *, 8> EValues(E->Scalars.begin(),
-                                              E->Scalars.end());
-              return (
-                  all_of(VL, [&](Value *V) { return EValues.contains(V); }));
-            })) {
-          LLVM_DEBUG(dbgs() << "SLP: Gathering due to full overlap.\n");
-          if (TryToFindDuplicates(S))
-            newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
-                         ReuseShuffleIndices);
-          return;
-        }
-      } else {
-        // Record the reuse of the tree node.  FIXME, currently this is only
-        // used to properly draw the graph rather than for the actual
-        // vectorization.
+    LLVM_DEBUG(dbgs() << "SLP: \tChecking bundle: " << *S.getMainOp() << ".\n");
+    for (TreeEntry *E : getTreeEntries(S.getMainOp())) {
+      if (E->isSame(VL)) {
+        // Record the reuse of the tree node.
         E->UserTreeIndices.push_back(UserTreeIdx);
         LLVM_DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.getMainOp()
                           << ".\n");
         return;
       }
+      SmallPtrSet<Value *, 8> Values(E->Scalars.begin(), E->Scalars.end());
+      if (all_of(VL, [&](Value *V) {
+            return isa<PoisonValue>(V) || Values.contains(V);
+          })) {
+        LLVM_DEBUG(dbgs() << "SLP: Gathering due to full overlap.\n");
+        if (TryToFindDuplicates(S))
+          newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
+                       ReuseShuffleIndices);
+        return;
+      }
     }
   }
 
@@ -8371,7 +8355,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     if ((!IsScatterVectorizeUserTE && !isa<Instruction>(V)) ||
         doesNotNeedToBeScheduled(V))
       continue;
-    if (getTreeEntry(V)) {
+    if (isVectorized(V)) {
       LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V
                         << ") is already in tree.\n");
       if (TryToFindDuplicates(S))
@@ -9029,8 +9013,7 @@ bool BoUpSLP::areAllUsersVectorized(
     Instruction *I, const SmallDenseSet<Value *> *VectorizedVals) const {
   return (I->hasOneUse() && (!VectorizedVals || VectorizedVals->contains(I))) ||
          all_of(I->users(), [this](User *U) {
-           return ScalarToTreeEntry.contains(U) ||
-                  isVectorLikeInstWithConstOps(U) ||
+           return isVectorized(U) || isVectorLikeInstWithConstOps(U) ||
                   (isa<ExtractElementInst>(U) && MustGather.contains(U));
          });
 }
@@ -9844,13 +9827,9 @@ void BoUpSLP::transformNodes() {
           ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
           // If any instruction is vectorized already - do not try again.
           // Reuse the existing node, if it fully matches the slice.
-          if (const TreeEntry *SE = getTreeEntry(Slice.front());
-              SE || getTreeEntry(Slice.back())) {
-            if (!SE)
-              continue;
-            if (VF != SE->getVectorFactor() || !SE->isSame(Slice))
-              continue;
-          }
+          if (isVectorized(Slice.front()) &&
+              !getSameValuesTreeEntry(Slice.front(), Slice, /*SameVF=*/true))
+            continue;
           // Constant already handled effectively - skip.
           if (allConstant(Slice))
             continue;
@@ -9933,12 +9912,8 @@ void BoUpSLP::transformNodes() {
         for (auto [Cnt, Sz] : Slices) {
           ArrayRef<Value *> Slice = VL.slice(Cnt, Sz);
           // If any instruction is vectorized already - do not try again.
-          if (TreeEntry *SE = getTreeEntry(Slice.front());
-              SE || getTreeEntry(Slice.back())) {
-            if (!SE)
-              continue;
-            if (VF != SE->getVectorFactor() || !SE->isSame(Slice))
-              continue;
+          if (TreeEntry *SE = getSameValuesTreeEntry(Slice.front(), Slice,
+                                                     /*SameVF=*/true)) {
             SE->UserTreeIndices.emplace_back(&E, UINT_MAX);
             AddCombinedNode(SE->Idx, Cnt, Sz);
             continue;
@@ -10724,7 +10699,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
         auto *EE = cast<ExtractElementInst>(V);
         VecBase = EE->getVectorOperand();
         UniqueBases.insert(VecBase);
-        const TreeEntry *VE = R.getTreeEntry(V);
+        ArrayRef<TreeEntry *> VEs = R.getTreeEntries(V);
         if (!CheckedExtracts.insert(V).second ||
             !R.areAllUsersVectorized(cast<Instruction>(V), &VectorizedVals) ||
             any_of(EE->users(),
@@ -10733,7 +10708,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
                             !R.areAllUsersVectorized(cast<Instruction>(U),
                                                      &VectorizedVals);
                    }) ||
-            (VE && VE != E))
+            (!VEs.empty() && !is_contained(VEs, E)))
           continue;
         std::optional<unsigned> EEIdx = getExtractIndex(EE);
         if (!EEIdx)
@@ -11166,13 +11141,14 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
   const unsigned Sz = UniqueValues.size();
   SmallBitVector UsedScalars(Sz, false);
   for (unsigned I = 0; I < Sz; ++I) {
-    if (isa<Instruction>(UniqueValues[I]) && getTreeEntry(UniqueValues[I]) == E)
+    if (isa<Instruction>(UniqueValues[I]) &&
+        is_contained(getTreeEntries(UniqueValues[I]), E))
       continue;
     UsedScalars.set(I);
   }
   auto GetCastContextHint = [&](Value *V) {
-    if (const TreeEntry *OpTE = getTreeEntry(V))
-      return getCastContextHint(*OpTE);
+    if (ArrayRef<TreeEntry *> OpTEs = getTreeEntries(V); OpTEs.size() == 1)
+      return getCastContextHint(*OpTEs.front());
     InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI);
     if (SrcState && SrcState.getOpcode() == Instruction::Load &&
         !SrcState.isAltShuffle())
@@ -11294,11 +11270,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         Value *Op = PHI->getIncomingValue(I);
         Operands[I] = Op;
       }
-      if (const TreeEntry *OpTE = getTreeEntry(Operands.front()))
-        if (OpTE->isSame(Operands) && CountedOps.insert(OpTE).second)
-          if (!OpTE->ReuseShuffleIndices.empty())
-            ScalarCost += TTI::TCC_Basic * (OpTE->ReuseShuffleIndices.size() -
-                                            OpTE->Scalars.size());
+      if (const TreeEntry *OpTE =
+              getSameValuesTreeEntry(Operands.front(), Operands))
+        if (CountedOps.insert(OpTE).second &&
+            !OpTE->ReuseShuffleIndices.empty())
+          ScalarCost += TTI::TCC_Basic * (OpTE->ReuseShuffleIndices.size() -
+                                          OpTE->Scalars.size());
     }
 
     return CommonCost - ScalarCost;
@@ -12231,7 +12208,7 @@ InstructionCost BoUpSLP::getSpillCost() const {
     // Update LiveValues.
     LiveValues.erase(PrevInst);
     for (auto &J : PrevInst->operands()) {
-      if (isa<Instruction>(&*J) && getTreeEntry(&*J))
+      if (isa<Instruction>(&*J) && isVectorized(&*J))
         LiveValues.insert(cast<Instruction>(&*J));
     }
 
@@ -12478,9 +12455,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       continue;
     }
     if (TE.isGather() && TE.hasState()) {
-      if (const TreeEntry *E = getTreeEntry(TE.getMainOp());
-          E && E->getVectorFactor() == TE.getVectorFactor() &&
-          E->isSame(TE.Scalars)) {
+      if (const TreeEntry *E =
+              getSameValuesTreeEntry(TE.getMainOp(), TE.Scalars);
+          E && E->getVectorFactor() == TE.getVectorFactor()) {
         // Some gather nodes might be absolutely the same as some vectorizable
         // nodes after reordering, need to handle it.
         LLVM_DEBUG(dbgs() << "SLP: Adding cost 0 for bundle "
@@ -12552,7 +12529,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
           continue;
         std::optional<unsigned> InsertIdx = getElementIndex(VU);
         if (InsertIdx) {
-          const TreeEntry *ScalarTE = getTreeEntry(EU.Scalar);
+          const TreeEntry *ScalarTE = &EU.E;
           auto *It = find_if(
               ShuffledInserts,
               [this, VU](const ShuffledInsertData<const TreeEntry *> &Data) {
@@ -12561,7 +12538,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
                 return areTwoInsertFromSameBuildVector(
                     VU, VecInsert, [this](InsertElementInst *II) -> Value * {
                       Value *Op0 = II->getOperand(0);
-                      if (getTreeEntry(II) && !getTreeEntry(Op0))
+                      if (isVectorized(II) && !isVectorized(Op0))
                         return nullptr;
                       return Op0;
                     });
@@ -12619,7 +12596,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
     // for the extract and the added cost of the sign extend if needed.
     InstructionCost ExtraCost = TTI::TCC_Free;
     auto *VecTy = getWidenedType(EU.Scalar->getType(), BundleWidth);
-    const TreeEntry *Entry = getTreeEntry(EU.Scalar);
+    const TreeEntry *Entry = &EU.E;
     auto It = MinBWs.find(Entry);
     if (It != MinBWs.end()) {
       auto *MinTy = IntegerType::get(F->getContext(), It->second.first);
@@ -12662,7 +12639,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       auto *Inst = cast<Instruction>(EU.Scalar);
       InstructionCost ScalarCost = TTI->getInstructionCost(Inst, CostKind);
       auto OperandIsScalar = [&](Value *V) {
-        if (!getTreeEntry(V)) {
+        if (!isVectorized(V)) {
           // Some extractelements might be not vectorized, but
           // transformed into shuffle and removed from the function,
           // consider it here.
@@ -12678,7 +12655,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
         if (auto *Op = dyn_cast<Instruction>(CI->getOperand(0));
             Op && all_of(Op->operands(), OperandIsScalar)) {
           InstructionCost OpCost =
-              (getTreeEntry(Op) && !ValueToExtUses->contains(Op))
+              (isVectorized(Op) && !ValueToExtUses->contains(Op))
                   ? TTI->getInstructionCost(Op, CostKind)
                   : 0;
           if (ScalarCost + OpCost <= ExtraCost) {
@@ -12705,7 +12682,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
                                   cast<Instruction>(
                                       VectorizableTree.front()->getMainOp())
                                       ->getParent()) &&
-                             !getTreeEntry(U);
+                             !isVectorized(U);
                     }) &&
             count_if(Entry->Scalars, [&](Value *V) {
               return ValueToExtUses->contains(V);
@@ -12767,8 +12744,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   // instead of extractelement.
   for (Value *V : ScalarOpsFromCasts) {
     ExternalUsesAsOriginalScalar.insert(V);
-    if (const TreeEntry *E = getTreeEntry(V)) {
-      ExternalUses.emplace_back(V, nullptr, E->findLaneForValue(V));
+    if (ArrayRef<TreeEntry *> TEs = getTreeEntries(V); !TEs.empty()) {
+      ExternalUses.emplace_back(V, nullptr, *TEs.front(),
+                                TEs.front()->findLaneForValue(V));
     }
   }
   // Add reduced value cost, if resized.
@@ -13188,21 +13166,18 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
         continue;
       VToTEs.insert(TEPtr);
     }
-    if (const TreeEntry *VTE = getTreeEntry(V)) {
-      if (ForOrder && VTE->Idx < GatheredLoadsEntriesFirst.value_or(0)) {
-        if (VTE->State != TreeEntry::Vectorize) {
-          auto It = MultiNodeScalars.find(V);
-          if (It == MultiNodeScalars.end())
-            continue;
-          VTE = *It->getSecond().begin();
-          // Iterate through all vectorized nodes.
-          auto *MIt = find_if(It->getSecond(), [](const TreeEntry *MTE) {
-            return MTE->State == TreeEntry::Vectorize;
-          });
-          if (MIt == It->getSecond().end())
-            continue;
-          VTE = *MIt;
-        }
+    if (ArrayRef<TreeEntry *> VTEs = getTreeEntries(V); !VTEs.empty()) {
+      const TreeEntry *VTE = VTEs.front();
+      if (ForOrder && VTE->Idx < GatheredLoadsEntriesFirst.value_or(0) &&
+          VTEs.size() > 1 && VTE->State != TreeEntry::Vectorize) {
+        VTEs = VTEs.drop_front();
+        // Iterate through all vectorized nodes.
+        const auto *MIt = find_if(VTEs, [](const TreeEntry *MTE) {
+          return MTE->State == TreeEntry::Vectorize;
+        });
+        if (MIt == VTEs.end())
+          continue;
+        VTE = *MIt;
       }
       if (none_of(TE->CombinedEntriesWithIndices,
                   [&](const auto &P) { return P.first == VTE->Idx; })) {
@@ -13366,7 +13341,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
   // by extractelements processing) or may form vector node in future.
   auto MightBeIgnored = [=](Value *V) {
     auto *I = dyn_cast<Instruction>(V);
-    return I && !IsSplatOrUndefs && !ScalarToTreeEntry.count(I) &&
+    return I && !IsSplatOrUndefs && !isVectorized(I) &&
            !isVectorLikeInstWithConstOps(I) &&
            !areAllUsersVectorized(I, UserIgnoreList) && isSimple(I);
   };
@@ -13952,7 +13927,7 @@ Value *BoUpSLP::gather(
   for (int I = 0, E = VL.size(); I < E; ++I) {
     if (auto *Inst = dyn_cast<Instruction>(VL[I]))
       if ((CheckPredecessor(Inst->getParent(), Builder.GetInsertBlock()) ||
-           getTreeEntry(Inst) ||
+           isVectorized(Inst) ||
            (L && (!Root || L->isLoopInvariant(Root)) && L->contains(Inst))) &&
           PostponedIndices.insert(I).second)
         PostponedInsts.emplace_back(Inst, I);
@@ -13969,7 +13944,7 @@ Value *BoUpSLP::gather(
           isa_and_nonnull<SExtInst, ZExtInst>(CI)) {
         Value *Op = CI->getOperand(0);
         if (auto *IOp = dyn_cast<Instruction>(Op);
-            !IOp || !(isDeleted(IOp) || getTreeEntry(IOp)))
+            !IOp || !(isDeleted(IOp) || isVectorized(IOp)))
           V = Op;
       }
       Scalar = Builder.CreateIntCast(
@@ -13995,7 +13970,7 @@ Value *BoUpSLP::gather(
     CSEBlocks.insert(InsElt->getParent());
     // Add to our 'need-to-extract' list.
     if (isa<Instruction>(V)) {
-      if (TreeEntry *Entry = getTreeEntry(V)) {
+      if (ArrayRef<TreeEntry *> Entries = getTreeEntries(V); !Entries.empty()) {
         // Find which lane we need to extract.
         User *UserOp = nullptr;
         if (Scalar != V) {
@@ -14005,8 +13980,8 @@ Value *BoUpSLP::gather(
           UserOp = InsElt;
         }
         if (UserOp) {
-          unsigned FoundLane = Entry->findLaneForValue(V);
-          ExternalUses.emplace_back(V, UserOp, FoundLane);
+          unsigned FoundLane = Entries.front()->findLaneForValue(V);
+          ExternalUses.emplace_back(V, UserOp, *Entries.front(), FoundLane);
         }
       }
     }
@@ -14241,8 +14216,8 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
         continue;
       auto *EI = cast<ExtractElementInst>(VL[I]);
       VecBase = EI->getVectorOperand();
-      if (const TreeEntry *TE = R.getTreeEntry(VecBase))
-        VecBase = TE->VectorizedValue;
+      if (ArrayRef<TreeEntry *> TEs = R.getTreeEntries(VecBase); !TEs.empty())
+        VecBase = TEs.front()->VectorizedValue;
       assert(VecBase && "Expected vectorized value.");
       UniqueBases.insert(VecBase);
       // If the only one use is vectorized - can delete the extractelement
@@ -14250,18 +14225,20 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
       if (!EI->hasOneUse() || R.ExternalUsesAsOriginalScalar.contains(EI) ||
           (NumParts != 1 && count(VL, EI) > 1) ||
           any_of(EI->users(), [&](User *U) {
-            const TreeEntry *UTE = R.getTreeEntry(U);
-            return !UTE || R.MultiNodeScalars.contains(U) ||
+            ArrayRef<TreeEntry *> UTEs = R.getTreeEntries(U);
+            return UTEs.empty() || UTEs.size() > 1 ||
                    (isa<GetElementPtrInst>(U) &&
                     !R.areAllUsersVectorized(cast<Instruction>(U))) ||
-                   count_if(R.VectorizableTree,
-                            [&](const std::unique_ptr<TreeEntry> &TE) {
-                              return any_of(TE->UserTreeIndices,
-                                            [&](const EdgeInfo &Edge) {
-                                              return Edge.UserTE == UTE;
-                                            }) &&
-                                     is_contained(VL, EI);
-                            }) != 1;
+                   (!UTEs.empty() &&
+                    count_if(R.VectorizableTree,
+                             [&](const std::unique_ptr<TreeEntry> &TE) {
+                               return any_of(TE->UserTreeIndices,
+                                             [&](const EdgeInfo &Edge) {
+                                               return Edge.UserTE ==
+                                                      UTEs.front();
+                                             }) &&
+                                      is_contained(VL, EI);
+                             }) != 1);
           }))
         continue;
       R.eraseInstruction(EI);
@@ -14296,8 +14273,9 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
               return S;
             Value *VecOp =
                 cast<ExtractElementInst>(std::get<0>(D))->getVectorOperand();
-            if (const TreeEntry *TE = R.getTreeEntry(VecOp))
-              VecOp = TE->VectorizedValue;
+            if (ArrayRef<TreeEntry *> TEs = R.getTreeEntries(VecOp);
+                !TEs.empty())
+              VecOp = TEs.front()->VectorizedValue;
             assert(VecOp && "Expected vectorized value.");
             const unsigned Size =
                 cast<FixedVectorType>(VecOp->getType())->getNumElements();
@@ -14307,8 +14285,8 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
         if (I == PoisonMaskElem)
           continue;
         Value *VecOp = cast<ExtractElementInst>(V)->getVectorOperand();
-        if (const TreeEntry *TE = R.getTreeEntry(VecOp))
-          VecOp = TE->VectorizedValue;
+        if (ArrayRef<TreeEntry *> TEs = R.getTreeEntries(VecOp); !TEs.empty())
+          VecOp = TEs.front()->VectorizedValue;
         assert(VecOp && "Expected vectorized value.");
         VecOp = castToScalarTyElem(VecOp);
         Bases[I / VF] = VecOp;
@@ -14634,29 +14612,20 @@ BoUpSLP::TreeEntry *BoUpSLP::getMatchedVectorizedOperand(const TreeEntry *E,
   if (!S)
     return nullptr;
   auto CheckSameVE = [&](const TreeEntry *VE) {
-    return VE->isSame(VL) &&
-           (any_of(VE->UserTreeIndices,
-                   [E, NodeIdx](const EdgeInfo &EI) {
-                     return EI.UserTE == E && EI.EdgeIdx == NodeIdx;
-                   }) ||
-            any_of(VectorizableTree,
-                   [E, NodeIdx, VE](const std::unique_ptr<TreeEntry> &TE) {
-                     return TE->isOperandGatherNode(
-                                {const_cast<TreeEntry *>(E), NodeIdx}) &&
-                            VE->isSame(TE->Scalars);
-                   }));
+    return any_of(VE->UserTreeIndices,
+                  [E, NodeIdx](const EdgeInfo &EI) {
+                    return EI.UserTE == E && EI.EdgeIdx == NodeIdx;
+                  }) ||
+           any_of(VectorizableTree,
+                  [E, NodeIdx, VE](const std::unique_ptr<TreeEntry> &TE) {
+                    return TE->isOperandGatherNode(
+                               {const_cast<TreeEntry *>(E), NodeIdx}) &&
+                           VE->isSame(TE->Scalars);
+                  });
   };
-  TreeEntry *VE = getTreeEntry(S.getMainOp());
+  TreeEntry *VE = getSameValuesTreeEntry(S.getMainOp(), VL);
   if (VE && CheckSameVE(VE))
     return VE;
-  auto It = MultiNodeScalars.find(S.getMainOp());
-  if (It != MultiNodeScalars.end()) {
-    auto *I = find_if(It->getSecond(), [&](const TreeEntry *TE) {
-      return TE != VE && CheckSameVE(TE);
-    });
-    if (I != It->getSecond().end())
-      return *I;
-  }
   return nullptr;
 }
 
@@ -14874,9 +14843,10 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
       for (auto [Idx, I] : enumerate(ExtractMask)) {
         if (I == PoisonMaskElem)
           continue;
-        if (const auto *TE = getTreeEntry(
-                cast<ExtractElementInst>(StoredGS[Idx])->getVectorOperand()))
-          ExtractEntries.push_back(TE);
+        if (ArrayRef<TreeEntry *> TEs = getTreeEntries(
+                cast<ExtractElementInst>(StoredGS[Idx])->getVectorOperand());
+            !TEs.empty())
+          ExtractEntries.append(TEs.begin(), TEs.end());
       }
       if (std::optional<ResTy> Delayed =
               ShuffleBuilder.needToDelay(E, ExtractEntries)) {
@@ -14907,10 +14877,10 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
           any_of(E->Scalars, IsaPred<LoadInst>)) &&
          any_of(E->Scalars,
                 [this](Value *V) {
-                  return isa<LoadInst>(V) && getTreeEntry(V);
+                  return isa<LoadInst>(V) && isVectorized(V);
                 })) ||
         (E->hasState() && E->isAltShuffle()) ||
-        all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) ||
+        all_of(E->Scalars, [this](Value *V) { return isVectorized(V); }) ||
         isSplat(E->Scalars) ||
         (E->Scalars != GatheredScalars && GatheredScalars.size() <= 2)) {
       GatherShuffles =
@@ -15025,7 +14995,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
       // non-poisonous, or by freezing the incoming scalar value first.
       auto *It = find_if(Scalars, [this, E](Value *V) {
         return !isa<UndefValue>(V) &&
-               (getTreeEntry(V) || isGuaranteedNotToBePoison(V, AC) ||
+               (isVectorized(V) || isGuaranteedNotToBePoison(V, AC) ||
                 (E->UserTreeIndices.size() == 1 &&
                  any_of(V->uses(), [E](const Use &U) {
                    // Check if the value already used in the same operation in
@@ -15083,9 +15053,9 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
             continue;
           auto *EI = cast<ExtractElementInst>(StoredGS[I]);
           Value *VecOp = EI->getVectorOperand();
-          if (const auto *TE = getTreeEntry(VecOp))
-            if (TE->VectorizedValue)
-              VecOp = TE->VectorizedValue;
+          if (ArrayRef<TreeEntry *> TEs = getTreeEntries(VecOp);
+              !TEs.empty() && TEs.front()->VectorizedValue)
+            VecOp = TEs.front()->VectorizedValue;
           if (!Vec1) {
             Vec1 = VecOp;
           } else if (Vec1 != VecOp) {
@@ -15413,8 +15383,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
     case Instruction::ExtractElement: {
       Value *V = E->getSingleOperand(0);
-      if (const TreeEntry *TE = getTreeEntry(V))
-        V = TE->VectorizedValue;
+      if (ArrayRef<TreeEntry *> TEs = getTreeEntries(V); !TEs.empty())
+        V = TEs.front()->VectorizedValue;
       setInsertPointAfterBundle(E);
       V = FinalShuffle(V, E);
       E->VectorizedValue = V;
@@ -16344,13 +16314,13 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
   DenseMap<Value *, SmallVector<TreeEntry *>> PostponedValues;
   for (const TreeEntry *E : PostponedNodes) {
     auto *TE = const_cast<TreeEntry *>(E);
-    if (auto *VecTE = getTreeEntry(TE->Scalars.front()))
-      if (VecTE->isSame(TE->UserTreeIndices.front().UserTE->getOperand(
-              TE->UserTreeIndices.front().EdgeIdx)) &&
-          VecTE->isSame(TE->Scalars))
-        // Found gather node which is absolutely the same as one of the
-        // vectorized nodes. It may happen after reordering.
-        continue;
+    if (auto *VecTE = getSameValuesTreeEntry(
+            TE->Scalars.front(), TE->UserTreeIndices.front().UserTE->getOperand(
+                                     TE->UserTreeIndices.front().EdgeIdx));
+        VecTE && VecTE->isSame(TE->Scalars))
+      // Found gather node which is absolutely the same as one of the
+      // vectorized nodes. It may happen after reordering.
+      continue;
     auto *PrevVec = cast<Instruction>(TE->VectorizedValue);
     TE->VectorizedValue = nullptr;
     auto *UserI =
@@ -16392,14 +16362,8 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
              "Expected integer vector types only.");
       std::optional<bool> IsSigned;
       for (Value *V : TE->Scalars) {
-        if (const TreeEntry *BaseTE = getTreeEntry(V)) {
-          auto It = MinBWs.find(BaseTE);
-          if (It != MinBWs.end()) {
-            IsSigned = IsSigned.value_or(false) || It->second.second;
-            if (*IsSigned)
-              break;
-          }
-          for (const TreeEntry *MNTE : MultiNodeScalars.lookup(V)) {
+        if (isVectorized(V)) {
+          for (const TreeEntry *MNTE : getTreeEntries(V)) {
             auto It = MinBWs.find(MNTE);
             if (It != MinBWs.end()) {
               IsSigned = IsSigned.value_or(false) || It->second.second;
@@ -16475,7 +16439,7 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
     // has multiple uses of the same value.
     if (User && !is_contained(Scalar->users(), User))
       continue;
-    TreeEntry *E = getTreeEntry(Scalar);
+    const TreeEntry *E = &ExternalUse.E;
     assert(E && "Invalid scalar");
     assert(!E->isGather() && "Extracting from a gather list");
     // Non-instruction pointers are not deleted, just skip them.
@@ -16533,8 +16497,8 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
                      ES && isa<Instruction>(Vec)) {
             Value *V = ES->getVectorOperand();
             auto *IVec = cast<Instruction>(Vec);
-            if (const TreeEntry *ETE = getTreeEntry(V))
-              V = ETE->VectorizedValue;
+            if (ArrayRef<TreeEntry *> ETEs = getTreeEntries(V); !ETEs.empty())
+              V = ETEs.front()->VectorizedValue;
             if (auto *IV = dyn_cast<Instruction>(V);
                 !IV || IV == Vec || IV->getParent() != IVec->getParent() ||
                 IV->comesBefore(IVec))
@@ -16587,27 +16551,31 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
     if (!User) {
       if (!ScalarsWithNullptrUser.insert(Scalar).second)
         continue;
-      assert((ExternallyUsedValues.count(Scalar) ||
-              Scalar->hasNUsesOrMore(UsesLimit) ||
-              ExternalUsesAsOriginalScalar.contains(Scalar) ||
-              any_of(Scalar->users(),
-                     [&](llvm::User *U) {
-                       if (ExternalUsesAsOriginalScalar.contains(U))
-                         return true;
-                       TreeEntry *UseEntry = getTreeEntry(U);
-                       return UseEntry &&
-                              (UseEntry->State == TreeEntry::Vectorize ||
-                               UseEntry->State ==
-                                   TreeEntry::StridedVectorize) &&
-                              (E->State == TreeEntry::Vectorize ||
-                               E->State == TreeEntry::StridedVectorize) &&
-                              doesInTreeUserNeedToExtract(
-                                  Scalar, getRootEntryInstruction(*UseEntry),
-                                  TLI, TTI);
-                     })) &&
-             "Scalar with nullptr User must be registered in "
-             "ExternallyUsedValues map or remain as scalar in vectorized "
-             "instructions");
+      assert(
+          (ExternallyUsedValues.count(Scalar) ||
+           Scalar->hasNUsesOrMore(UsesLimit) ||
+           ExternalUsesAsOriginalScalar.contains(Scalar) ||
+           any_of(
+               Scalar->users(),
+               [&, TTI = TTI](llvm::User *U) {
+                 if (ExternalUsesAsOriginalScalar.contains(U))
+                   return true;
+                 ArrayRef<TreeEntry *> UseEntries = getTreeEntries(U);
+                 return !UseEntries.empty() &&
+                        (E->State == TreeEntry::Vectorize ||
+                         E->State == TreeEntry::StridedVectorize) &&
+                        any_of(UseEntries, [&, TTI = TTI](TreeEntry *UseEntry) {
+                          return (UseEntry->State == TreeEntry::Vectorize ||
+                                  UseEntry->State ==
+                                      TreeEntry::StridedVectorize) &&
+                                 doesInTreeUserNeedToExtract(
+                                     Scalar, getRootEntryInstruction(*UseEntry),
+                                     TLI, TTI);
+                        });
+               })) &&
+          "Scalar with nullptr User must be registered in "
+          "ExternallyUsedValues map or remain as scalar in vectorized "
+          "instructions");
       if (auto *VecI = dyn_cast<Instruction>(Vec)) {
         if (auto *PHI = dyn_cast<PHINode>(VecI)) {
           if (PHI->getParent()->isLandingPad())
@@ -16870,7 +16838,7 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
           LLVM_DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n");
 
           // It is legal to delete users in the ignorelist.
-          assert((getTreeEntry(U) ||
+          assert((isVectorized(U) ||
                   (UserIgnoreList && UserIgnoreList->contains(U)) ||
                   (isa_and_nonnull<Instruction>(U) &&
                    isDeleted(cast<Instruction>(U)))) &&
@@ -16892,7 +16860,7 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
   // Clear up reduction references, if any.
   if (UserIgnoreList) {
     for (Instruction *I : RemovedInsts) {
-      const TreeEntry *IE = getTreeEntry(I);
+      const TreeEntry *IE = getTreeEntries(I).front();
       if (IE->Idx != 0 &&
           !(VectorizableTree.front()->isGather() &&
             !IE->UserTreeIndices.empty() &&
@@ -17607,10 +17575,11 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) {
   for (auto *I = BS->ScheduleStart; I != BS->ScheduleEnd;
        I = I->getNextNode()) {
     if (ScheduleData *SD = BS->getScheduleData(I)) {
-      [[maybe_unused]] TreeEntry *SDTE = getTreeEntry(SD->Inst);
+      [[maybe_unused]] ArrayRef<TreeEntry *> SDTEs = getTreeEntries(SD->Inst);
       assert((isVectorLikeInstWithConstOps(SD->Inst) ||
               SD->isPartOfBundle() ==
-                  (SDTE && !doesNotNeedToSchedule(SDTE->Scalars))) &&
+                  (!SDTEs.empty() &&
+                   !doesNotNeedToSchedule(SDTEs.front()->Scalars))) &&
              "scheduler and vectorizer bundle mismatch");
       SD->FirstInBundle->SchedulingPriority = Idx++;
 
@@ -17772,7 +17741,7 @@ bool BoUpSLP::collectValuesToDemote(
   auto IsPotentiallyTruncated = [&](Value *V, unsigned &BitWidth) -> bool {
     if (isa<PoisonValue>(V))
       return true;
-    if (MultiNodeScalars.contains(V))
+    if (getTreeEntries(V).size() > 1)
       return false;
     // For lat shuffle of sext/zext with many uses need to check the extra bit
     // for unsigned values, otherwise may have incorrect casting for reused
@@ -17834,14 +17803,14 @@ bool BoUpSLP::collectValuesToDemote(
   if (E.isGather() || !Visited.insert(&E).second ||
       any_of(E.Scalars, [&](Value *V) {
         return !isa<PoisonValue>(V) && all_of(V->users(), [&](User *U) {
-          return isa<InsertElementInst>(U) && !getTreeEntry(U);
+          return isa<InsertElementInst>(U) && !isVectorized(U);
         });
       }))
     return FinalAnalysis();
 
   if (any_of(E.Scalars, [&](Value *V) {
         return !all_of(V->users(), [=](User *U) {
-          return getTreeEntry(U) ||
+          return isVectorized(U) ||
                  (E.Idx == 0 && UserIgnoreList &&
                   UserIgnoreList->contains(U)) ||
                  (!isa<CmpInst>(U) && U->getType()->isSized() &&
@@ -18192,9 +18161,9 @@ void BoUpSLP::computeMinimumValueSizes() {
           return V->hasOneUse() || isa<Constant>(V) ||
                  (!V->hasNUsesOrMore(UsesLimit) &&
                   none_of(V->users(), [&](User *U) {
-                    const TreeEntry *TE = getTreeEntry(U);
+                    ArrayRef<TreeEntry *> TEs = getTreeEntries(U);
                     const TreeEntry *UserTE = E.UserTreeIndices.back().UserTE;
-                    if (TE == UserTE || !TE)
+                    if (TEs.empty() || is_contained(TEs, UserTE))
                       return false;
                     if (!isa<CastInst, BinaryOperator, FreezeInst, PHINode,
                              SelectInst>(U) ||
@@ -18203,8 +18172,11 @@ void BoUpSLP::computeMinimumValueSizes() {
                       return true;
                     unsigned UserTESz = DL->getTypeSizeInBits(
                         UserTE->Scalars.front()->getType());
-                    auto It = MinBWs.find(TE);
-                    if (It != MinBWs.end() && It->second.first > UserTESz)
+                    if (all_of(TEs, [&](const TreeEntry *TE) {
+                          auto It = MinBWs.find(TE);
+                          return It != MinBWs.end() &&
+                                 It->second.first > UserTESz;
+                        }))
                       return true;
                     return DL->getTypeSizeInBits(U->getType()) > UserTESz;
                   }));

>From a3c7455c453e8547f35f47285adb286d9e0725f8 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Wed, 29 Jan 2025 12:20:25 +0000
Subject: [PATCH 2/2] Fix formatting

Created using spr 1.3.5
---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 790c4dba0dc36b..4204f35d1a20d6 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -2963,7 +2963,6 @@ class BoUpSLP {
     return ScalarToTreeEntries.contains(V);
   }
 
-
   ~BoUpSLP();
 
 private:



More information about the llvm-commits mailing list