[llvm] [SLP]Support for tree throttling in SLP graphs with gathered loads (PR #177855)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 25 06:45:58 PST 2026


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/177855

>From 696572c95d6780aec0daa413424a688aff714d86 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Sun, 25 Jan 2026 05:33:28 -0800
Subject: [PATCH 1/2] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.7
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 110 +++++++++++++++---
 .../X86/vec_list_bias-inseltpoison.ll         |  25 ++--
 2 files changed, 108 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 22b8ea481f2c3..9db577256d957 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -4519,6 +4519,10 @@ class slpvectorizer::BoUpSLP {
   LLVM_DUMP_METHOD void dumpVectorizableTree() const {
     for (unsigned Id = 0, IdE = VectorizableTree.size(); Id != IdE; ++Id) {
       VectorizableTree[Id]->dump();
+      if (TransformedToGatherNodes.contains(VectorizableTree[Id].get()))
+        dbgs() << "[[TRANSFORMED TO GATHER]]";
+      else if (DeletedNodes.contains(VectorizableTree[Id].get()))
+        dbgs() << "[[DELETED NODE]]";
       dbgs() << "\n";
     }
   }
@@ -16482,7 +16486,7 @@ InstructionCost BoUpSLP::calculateTreeCostAndTrimNonProfitable(
     ArrayRef<Value *> VectorizedVals) {
   SmallDenseMap<const TreeEntry *, InstructionCost> NodesCosts;
   SmallPtrSet<Value *, 4> CheckedExtracts;
-  SmallPtrSet<const TreeEntry *, 4> GatheredLoadsNodes;
+  SmallSetVector<TreeEntry *, 4> GatheredLoadsNodes;
   LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size "
                     << VectorizableTree.size() << ".\n");
   InstructionCost Cost = 0;
@@ -16533,11 +16537,6 @@ InstructionCost BoUpSLP::calculateTreeCostAndTrimNonProfitable(
   if (SLPCostThreshold.getNumOccurrences() > 0 && SLPCostThreshold < 0 &&
       Cost < -SLPCostThreshold)
     return Cost;
-  // Bail out, if gathered loads nodes are found.
-  // TODO: add analysis for gathered load to include their cost correctly into
-  // the related subtrees.
-  if (!GatheredLoadsNodes.empty())
-    return Cost;
   // The narrow non-profitable tree in loop? Skip, may cause regressions.
   constexpr unsigned PartLimit = 2;
   const unsigned Sz =
@@ -16551,17 +16550,38 @@ InstructionCost BoUpSLP::calculateTreeCostAndTrimNonProfitable(
     return Cost;
   SmallVector<std::pair<InstructionCost, SmallVector<unsigned>>> SubtreeCosts(
       VectorizableTree.size());
+  auto UpdateParentNodes =
+      [&](const TreeEntry *UserTE, const TreeEntry *TE, InstructionCost C,
+          SmallDenseSet<std::pair<const TreeEntry *, const TreeEntry *>, 4>
+              &VisitedUser,
+          bool AddToList = true) {
+        while (UserTE &&
+               VisitedUser.insert(std::make_pair(TE, UserTE)).second) {
+          SubtreeCosts[UserTE->Idx].first += C;
+          if (AddToList)
+            SubtreeCosts[UserTE->Idx].second.push_back(TE->Idx);
+          UserTE = UserTE->UserTreeIndex.UserTE;
+        }
+      };
   for (const std::unique_ptr<TreeEntry> &Ptr : VectorizableTree) {
     TreeEntry &TE = *Ptr;
     InstructionCost C = NodesCosts.at(&TE);
     SubtreeCosts[TE.Idx].first += C;
-    const TreeEntry *UserTE = TE.UserTreeIndex.UserTE;
-    while (UserTE) {
-      SubtreeCosts[UserTE->Idx].first += C;
-      SubtreeCosts[UserTE->Idx].second.push_back(TE.Idx);
-      UserTE = UserTE->UserTreeIndex.UserTE;
+    if (const TreeEntry *UserTE = TE.UserTreeIndex.UserTE) {
+      SmallDenseSet<std::pair<const TreeEntry *, const TreeEntry *>, 4>
+          VisitedUser;
+      UpdateParentNodes(UserTE, &TE, C, VisitedUser);
+    }
+  }
+  SmallDenseSet<std::pair<const TreeEntry *, const TreeEntry *>, 4> Visited;
+  for (TreeEntry *TE : GatheredLoadsNodes) {
+    InstructionCost C = SubtreeCosts[TE->Idx].first;
+    for (Value *V : TE->Scalars) {
+      for (const TreeEntry *BVTE : ValueToGatherNodes.lookup(V))
+        UpdateParentNodes(BVTE, TE, C, Visited, /*AddToList=*/false);
     }
   }
+  Visited.clear();
   using CostIndicesTy =
       std::pair<TreeEntry *, std::pair<InstructionCost, SmallVector<unsigned>>>;
   struct FirstGreater {
@@ -16583,6 +16603,7 @@ InstructionCost BoUpSLP::calculateTreeCostAndTrimNonProfitable(
       (Worklist.top().first->Idx == 0 || Worklist.top().first->Idx == 1))
     return Cost;
 
+  constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   bool Changed = false;
   while (!Worklist.empty() && Worklist.top().second.first > 0) {
     TreeEntry *TE = Worklist.top().first;
@@ -16624,7 +16645,6 @@ InstructionCost BoUpSLP::calculateTreeCostAndTrimNonProfitable(
       if (isConstant(V))
         DemandedElts.clearBit(Idx);
     }
-    constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
 
     Type *ScalarTy = getValueType(TE->Scalars.front());
     auto *VecTy = getWidenedType(ScalarTy, Sz);
@@ -16663,7 +16683,7 @@ InstructionCost BoUpSLP::calculateTreeCostAndTrimNonProfitable(
     // Erase subtree if it is non-profitable.
     if (SubtreeCost > GatherCost) {
       // If the remaining tree is just a buildvector - exit, it will cause
-      // enless attempts to vectorize.
+      // endless attempts to vectorize.
       if (VectorizableTree.front()->hasState() &&
           VectorizableTree.front()->getOpcode() == Instruction::InsertElement &&
           TE->Idx == 1)
@@ -16694,6 +16714,68 @@ InstructionCost BoUpSLP::calculateTreeCostAndTrimNonProfitable(
   if (!Changed)
     return SubtreeCosts.front().first;
 
+  SmallPtrSet<TreeEntry *, 4> GatheredLoadsToDelete;
+  InstructionCost LoadsExtractsCost = 0;
+  // If none of the live buildvector nodes uses the scalars of a gathered loads
+  // node, the whole gathered loads subtree must be deleted. Otherwise, compare
+  // the cost of extracting the still-used loads from the vectorized node with
+  // the cost of inserting them as scalars into the buildvectors, and keep the
+  // node vectorized only if the extracts are cheaper.
+  for (TreeEntry *TE : GatheredLoadsNodes) {
+    if (DeletedNodes.contains(TE) || TransformedToGatherNodes.contains(TE))
+      continue;
+    GatheredLoadsToDelete.insert(TE);
+    APInt DemandedElts = APInt::getZero(TE->getVectorFactor());
+    // If no live buildvector node uses any of these loads, delete the subtree.
+    SmallDenseMap<const TreeEntry *, SmallVector<Value *>> ValuesToInsert;
+    for (Value *V : TE->Scalars) {
+      unsigned Pos = TE->findLaneForValue(V);
+      for (const TreeEntry *BVE : ValueToGatherNodes.lookup(V)) {
+        if (DeletedNodes.contains(BVE))
+          continue;
+        DemandedElts.setBit(Pos);
+        ValuesToInsert.try_emplace(BVE).first->second.push_back(V);
+      }
+    }
+    if (!DemandedElts.isZero()) {
+      Type *ScalarTy = TE->Scalars.front()->getType();
+      auto *VecTy = getWidenedType(ScalarTy, TE->getVectorFactor());
+      InstructionCost ExtractsCost = ::getScalarizationOverhead(
+          *TTI, ScalarTy, VecTy, DemandedElts,
+          /*Insert=*/false, /*Extract=*/true, CostKind);
+      InstructionCost BVCost = 0;
+      for (const auto &[BVE, Values] : ValuesToInsert) {
+        APInt BVDemandedElts = APInt::getZero(BVE->getVectorFactor());
+        for (Value *V : Values) {
+          unsigned Pos = BVE->findLaneForValue(V);
+          BVDemandedElts.setBit(Pos);
+        }
+        auto *BVVecTy = getWidenedType(ScalarTy, BVE->getVectorFactor());
+        BVCost += ::getScalarizationOverhead(
+            *TTI, ScalarTy, BVVecTy, BVDemandedElts,
+            /*Insert=*/true, /*Extract=*/false, CostKind,
+            BVDemandedElts.isAllOnes(), Values);
+      }
+      if (ExtractsCost < BVCost) {
+        LoadsExtractsCost += ExtractsCost;
+        GatheredLoadsToDelete.erase(TE);
+        continue;
+      }
+      LoadsExtractsCost += BVCost;
+    }
+    NodesCosts.erase(TE);
+  }
+
+  // Delete all nodes below the gathered loads nodes selected for deletion.
+  for (std::unique_ptr<TreeEntry> &TE : VectorizableTree) {
+    if (TE->UserTreeIndex &&
+        GatheredLoadsToDelete.contains(TE->UserTreeIndex.UserTE)) {
+      DeletedNodes.insert(TE.get());
+      NodesCosts.erase(TE.get());
+      GatheredLoadsToDelete.insert(TE.get());
+    }
+  }
+
   for (std::unique_ptr<TreeEntry> &TE : VectorizableTree) {
     if (!TE->UserTreeIndex && TransformedToGatherNodes.contains(TE.get())) {
       assert(TE->getOpcode() == Instruction::Load && "Expected load only.");
@@ -16717,7 +16799,7 @@ InstructionCost BoUpSLP::calculateTreeCostAndTrimNonProfitable(
                       << ".\n"
                       << "SLP: Current total cost = " << Cost << "\n");
   }
-  if (NewCost >= Cost) {
+  if (NewCost + LoadsExtractsCost >= Cost) {
     DeletedNodes.clear();
     TransformedToGatherNodes.clear();
     NewCost = Cost;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/vec_list_bias-inseltpoison.ll b/llvm/test/Transforms/SLPVectorizer/X86/vec_list_bias-inseltpoison.ll
index 2cc2f28ccf6d5..e3a6020a542fb 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/vec_list_bias-inseltpoison.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/vec_list_bias-inseltpoison.ll
@@ -25,6 +25,7 @@ define void @test(ptr nocapture %t2) {
 ; CHECK-NEXT:    [[T24:%.*]] = add nsw i32 [[T23]], [[T21]]
 ; CHECK-NEXT:    [[T25:%.*]] = sub nsw i32 [[T21]], [[T23]]
 ; CHECK-NEXT:    [[T27:%.*]] = sub nsw i32 [[T3]], [[T24]]
+; CHECK-NEXT:    [[T32:%.*]] = mul nsw i32 [[T27]], 6270
 ; CHECK-NEXT:    [[T37:%.*]] = add nsw i32 [[T25]], [[T11]]
 ; CHECK-NEXT:    [[T38:%.*]] = add nsw i32 [[T17]], [[T5]]
 ; CHECK-NEXT:    [[T39:%.*]] = add nsw i32 [[T37]], [[T38]]
@@ -33,6 +34,7 @@ define void @test(ptr nocapture %t2) {
 ; CHECK-NEXT:    [[T42:%.*]] = mul nsw i32 [[T17]], 16819
 ; CHECK-NEXT:    [[T47:%.*]] = mul nsw i32 [[T37]], -16069
 ; CHECK-NEXT:    [[T48:%.*]] = mul nsw i32 [[T38]], -3196
+; CHECK-NEXT:    [[T49:%.*]] = add nsw i32 [[T40]], [[T47]]
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x i32>, ptr [[T8]], align 4
 ; CHECK-NEXT:    [[T15:%.*]] = load i32, ptr [[T14]], align 4
 ; CHECK-NEXT:    [[T9:%.*]] = load i32, ptr [[T8]], align 4
@@ -40,20 +42,17 @@ define void @test(ptr nocapture %t2) {
 ; CHECK-NEXT:    [[T30:%.*]] = add nsw i32 [[T27]], [[T29]]
 ; CHECK-NEXT:    [[T31:%.*]] = mul nsw i32 [[T30]], 4433
 ; CHECK-NEXT:    [[T34:%.*]] = mul nsw i32 [[T29]], -15137
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x i32> [[TMP1]], <2 x i32> poison, <4 x i32> <i32 1, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[T40]], i32 1
-; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[T27]], i32 2
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x i32> [[TMP4]], i32 [[T47]], i32 3
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x i32> [[TMP1]], <2 x i32> poison, <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i32> <i32 poison, i32 poison, i32 6270, i32 poison>, <4 x i32> [[TMP6]], <4 x i32> <i32 4, i32 poison, i32 2, i32 poison>
-; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x i32> [[TMP7]], i32 [[T48]], i32 1
-; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x i32> [[TMP8]], i32 [[T40]], i32 3
-; CHECK-NEXT:    [[TMP10:%.*]] = add nsw <4 x i32> [[TMP5]], [[TMP9]]
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP5]], [[TMP9]]
-; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <4 x i32> [[TMP10]], <4 x i32> [[TMP11]], <4 x i32> <i32 0, i32 1, i32 6, i32 3>
-; CHECK-NEXT:    [[T701:%.*]] = shufflevector <4 x i32> [[TMP12]], <4 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 poison, i32 3>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x i32> [[TMP1]], <2 x i32> poison, <2 x i32> <i32 1, i32 poison>
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x i32> [[TMP2]], i32 [[T40]], i32 1
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[T48]], i32 1
+; CHECK-NEXT:    [[TMP5:%.*]] = add nsw <2 x i32> [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x i32> [[TMP5]], <2 x i32> poison, <8 x i32> <i32 0, i32 1, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[T67:%.*]] = insertelement <8 x i32> [[TMP6]], i32 [[T32]], i32 2
+; CHECK-NEXT:    [[T68:%.*]] = insertelement <8 x i32> [[T67]], i32 [[T49]], i32 3
+; CHECK-NEXT:    [[T701:%.*]] = shufflevector <8 x i32> [[T68]], <8 x i32> [[TMP6]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 8, i32 9, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[T71:%.*]] = insertelement <8 x i32> [[T701]], i32 [[T34]], i32 6
-; CHECK-NEXT:    [[T76:%.*]] = shl <8 x i32> [[T71]], splat (i32 3)
+; CHECK-NEXT:    [[T72:%.*]] = insertelement <8 x i32> [[T71]], i32 [[T49]], i32 7
+; CHECK-NEXT:    [[T76:%.*]] = shl <8 x i32> [[T72]], splat (i32 3)
 ; CHECK-NEXT:    store <8 x i32> [[T76]], ptr [[T2]], align 4
 ; CHECK-NEXT:    ret void
 ;

>From b98e16096457b928d777ba3447aca14bcf3608a3 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Sun, 25 Jan 2026 06:45:49 -0800
Subject: [PATCH 2/2] Fix a crash in TTI

Created using spr 1.3.7
---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 9db577256d957..fa9b86484c21f 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -16746,15 +16746,18 @@ InstructionCost BoUpSLP::calculateTreeCostAndTrimNonProfitable(
       InstructionCost BVCost = 0;
       for (const auto &[BVE, Values] : ValuesToInsert) {
         APInt BVDemandedElts = APInt::getZero(BVE->getVectorFactor());
+        SmallVector<Value *> BVValues(BVE->getVectorFactor(),
+                                      PoisonValue::get(ScalarTy));
         for (Value *V : Values) {
           unsigned Pos = BVE->findLaneForValue(V);
+          BVValues[Pos] = V;
           BVDemandedElts.setBit(Pos);
         }
         auto *BVVecTy = getWidenedType(ScalarTy, BVE->getVectorFactor());
         BVCost += ::getScalarizationOverhead(
             *TTI, ScalarTy, BVVecTy, BVDemandedElts,
             /*Insert=*/true, /*Extract=*/false, CostKind,
-            BVDemandedElts.isAllOnes(), Values);
+            BVDemandedElts.isAllOnes(), BVValues);
       }
       if (ExtractsCost < BVCost) {
         LoadsExtractsCost += ExtractsCost;



More information about the llvm-commits mailing list