[llvm] [SLP]Improve minbitwidth analysis. (PR #84334)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 7 07:37:06 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)


This improves the overall minbitwidth analysis in SLP. It allows
analyzing trees with store/insertelement root nodes. Also, instead of
using a single minbitwidth detected at the very first analysis stage,
it tries to detect the best bitwidth for each trunc/ext subtree in the
graph and uses it for that subtree.
This results in better code and less vector register pressure.
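
As a rough illustration of the kind of narrowing this enables (a hypothetical
example, not taken from the patch or its tests), consider a store-rooted tree
whose scalar computation happens in i32 but whose result is truncated back to
i8 before the store; with per-subtree analysis the vector body can stay in i8:

```llvm
; Hypothetical scalar input: lanes are widened to i32, added, and the
; result is truncated back to i8 before being stored.
define void @f(ptr %src, ptr %dst) {
  %p1 = getelementptr inbounds i8, ptr %src, i64 1
  %p2 = getelementptr inbounds i8, ptr %src, i64 2
  %p3 = getelementptr inbounds i8, ptr %src, i64 3
  %l0 = load i8, ptr %src
  %l1 = load i8, ptr %p1
  %l2 = load i8, ptr %p2
  %l3 = load i8, ptr %p3
  %e0 = zext i8 %l0 to i32
  %e1 = zext i8 %l1 to i32
  %e2 = zext i8 %l2 to i32
  %e3 = zext i8 %l3 to i32
  %a0 = add i32 %e0, 1
  %a1 = add i32 %e1, 1
  %a2 = add i32 %e2, 1
  %a3 = add i32 %e3, 1
  %t0 = trunc i32 %a0 to i8
  %t1 = trunc i32 %a1 to i8
  %t2 = trunc i32 %a2 to i8
  %t3 = trunc i32 %a3 to i8
  %q1 = getelementptr inbounds i8, ptr %dst, i64 1
  %q2 = getelementptr inbounds i8, ptr %dst, i64 2
  %q3 = getelementptr inbounds i8, ptr %dst, i64 3
  store i8 %t0, ptr %dst
  store i8 %t1, ptr %q1
  store i8 %t2, ptr %q2
  store i8 %t3, ptr %q3
  ret void
}

; With a store-rooted minbitwidth analysis the vectorized body can be kept
; narrow, roughly along the lines of:
;   %v   = load <4 x i8>, ptr %src
;   %add = add <4 x i8> %v, <i8 1, i8 1, i8 1, i8 1>
;   store <4 x i8> %add, ptr %dst
; instead of zext to <4 x i32>, add in <4 x i32>, and a vector trunc back
; to <4 x i8>.
```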

Metric: size..text

Program                                                                              results      results0     diff
test-suite :: SingleSource/Benchmarks/Adobe-C++/simple_types_loop_invariant.test      92549.00     92609.00    0.1%
test-suite :: External/SPEC/CINT2017speed/625.x264_s/625.x264_s.test                 663381.00    663493.00    0.0%
test-suite :: External/SPEC/CINT2017rate/525.x264_r/525.x264_r.test                  663381.00    663493.00    0.0%
test-suite :: MultiSource/Benchmarks/Bullet/bullet.test                              307182.00    307214.00    0.0%
test-suite :: External/SPEC/CFP2017speed/638.imagick_s/638.imagick_s.test           1394420.00   1394484.00    0.0%
test-suite :: External/SPEC/CFP2017rate/538.imagick_r/538.imagick_r.test            1394420.00   1394484.00    0.0%
test-suite :: External/SPEC/CFP2017rate/510.parest_r/510.parest_r.test              2040257.00   2040273.00    0.0%

test-suite :: External/SPEC/CFP2017rate/526.blender_r/526.blender_r.test           12396098.00  12395858.00   -0.0%
test-suite :: External/SPEC/CINT2006/445.gobmk/445.gobmk.test                        909944.00    909768.00   -0.0%

SingleSource/Benchmarks/Adobe-C++/simple_types_loop_invariant - 4 scalar
instructions remain scalar (good).
Spec2017/x264 - the whole function idct4x4dc is vectorized using <16
x i16> instead of <16 x i32>, and the zext/trunc instructions are removed.
In other places the last vector zext/sext is removed and replaced by an
extractelement + scalar zext/sext pair.
MultiSource/Benchmarks/Bullet/bullet - an or reduction over <4 x i32> is
replaced by an or reduction over <4 x i8>.
Spec2017/imagick - removed an extra zext from 2 packs of operations.
Spec2017/parest - removed an extra zext, replaced by extractelement + scalar
zext.
Spec2017/blender - a whole batch of vector zext/sext instructions is replaced
by extractelement + scalar zext/sext, and some extra code is vectorized in
smaller types.
Spec2006/gobmk - fixed cost estimation; some small code remains scalar.
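
The extractelement + scalar zext/sext pattern mentioned for x264, parest and
blender can be sketched as follows (a hypothetical shape of the change, not
taken from the benchmarks): when only one lane of a narrowed <4 x i16> tree is
used externally as i32, the whole vector no longer needs to be widened.

```llvm
define i32 @use_one_lane(<4 x i16> %vec) {
  ; Before: %wide = zext <4 x i16> %vec to <4 x i32>
  ;         %lane = extractelement <4 x i32> %wide, i32 0
  ; After (the cheaper pattern described above):
  %lane16 = extractelement <4 x i16> %vec, i32 0
  %lane = zext i16 %lane16 to i32
  ret i32 %lane
}
```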


---

Patch is 71.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84334.diff


15 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+440-194) 
- (modified) llvm/test/Transforms/SLPVectorizer/AArch64/ext-trunc.ll (+5-4) 
- (modified) llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll (+2-2) 
- (modified) llvm/test/Transforms/SLPVectorizer/AArch64/reduce-add-i64.ll (+5-15) 
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll (+4-3) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/PR35777.ll (+5-4) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/int-bitcast-minbitwidth.ll (+1-1) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-multiuse-with-insertelement.ll (+8-9) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll (+13-8) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll (+24-19) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/phi-undef-input.ll (+12-12) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/resched.ll (+16-16) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/reused-reductions-with-minbitwidth.ll (+4-6) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/store-insertelement-minbitwidth.ll (+12-10) 
- (modified) llvm/test/Transforms/SLPVectorizer/alt-cmp-vectorize.ll (+2-2) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 36dc9094538ae9..1889bc09e85028 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1085,6 +1085,9 @@ class BoUpSLP {
       BS->clear();
     }
     MinBWs.clear();
+    ReductionBitWidth = 0;
+    CastMaxMinBWSizes.reset();
+    TruncNodes.clear();
     InstrElementSize.clear();
     UserIgnoreList = nullptr;
     PostponedGathers.clear();
@@ -2287,6 +2290,7 @@ class BoUpSLP {
   void clearReductionData() {
     AnalyzedReductionsRoots.clear();
     AnalyzedReductionVals.clear();
+    AnalyzedMinBWVals.clear();
   }
   /// Checks if the given value is gathered in one of the nodes.
   bool isAnyGathered(const SmallDenseSet<Value *> &Vals) const {
@@ -2307,9 +2311,11 @@ class BoUpSLP {
   /// constant and to be demoted. Required to correctly identify constant nodes
   /// to be demoted.
   bool collectValuesToDemote(
-      Value *V, SmallVectorImpl<Value *> &ToDemote,
+      Value *V, bool IsProfitableToDemoteRoot, unsigned &BitWidth,
+      SmallVectorImpl<Value *> &ToDemote,
       DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts,
-      SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const;
+      DenseSet<Value *> &Visited, unsigned &MaxDepthLevel,
+      bool &IsProfitableToDemote) const;
 
   /// Check if the operands on the edges \p Edges of the \p UserTE allows
   /// reordering (i.e. the operands can be reordered because they have only one
@@ -2375,6 +2381,10 @@ class BoUpSLP {
   /// \ returns the graph entry for the \p Idx operand of the \p E entry.
   const TreeEntry *getOperandEntry(const TreeEntry *E, unsigned Idx) const;
 
+  /// \returns Cast context for the given graph node.
+  TargetTransformInfo::CastContextHint
+  getCastContextHint(const TreeEntry &TE) const;
+
   /// \returns the cost of the vectorizable entry.
   InstructionCost getEntryCost(const TreeEntry *E,
                                ArrayRef<Value *> VectorizedVals,
@@ -2925,11 +2935,18 @@ class BoUpSLP {
       }
       assert(!BundleMember && "Bundle and VL out of sync");
     } else {
-      MustGather.insert(VL.begin(), VL.end());
       // Build a map for gathered scalars to the nodes where they are used.
+      bool AllConstsOrCasts = true;
       for (Value *V : VL)
-        if (!isConstant(V))
+        if (!isConstant(V)) {
+          auto *I = dyn_cast<CastInst>(V);
+          AllConstsOrCasts &= I && I->getType()->isIntegerTy();
           ValueToGatherNodes.try_emplace(V).first->getSecond().insert(Last);
+        }
+      if (AllConstsOrCasts)
+        CastMaxMinBWSizes =
+            std::make_pair(std::numeric_limits<unsigned>::max(), 1);
+      MustGather.insert(VL.begin(), VL.end());
     }
 
     if (UserTreeIdx.UserTE)
@@ -3054,6 +3071,10 @@ class BoUpSLP {
   /// Set of hashes for the list of reduction values already being analyzed.
   DenseSet<size_t> AnalyzedReductionVals;
 
+  /// Values, already been analyzed for minimal bitwidth and found to be
+  /// non-profitable.
+  DenseSet<Value *> AnalyzedMinBWVals;
+
   /// A list of values that need to extracted out of the tree.
   /// This list holds pairs of (Internal Scalar : External User). External User
   /// can be nullptr, it means that this Internal Scalar will be used later,
@@ -3629,6 +3650,18 @@ class BoUpSLP {
   /// value must be signed-extended, rather than zero-extended, back to its
   /// original width.
   DenseMap<const TreeEntry *, std::pair<uint64_t, bool>> MinBWs;
+
+  /// Final size of the reduced vector, if the current graph represents the
+  /// input for the reduction and it was possible to narrow the size of the
+  /// reduction.
+  unsigned ReductionBitWidth = 0;
+
+  /// If the tree contains any zext/sext/trunc nodes, contains max-min pair of
+  /// type sizes, used in the tree.
+  std::optional<std::pair<unsigned, unsigned>> CastMaxMinBWSizes;
+
+  /// Indices of the vectorized trunc nodes.
+  DenseSet<unsigned> TruncNodes;
 };
 
 } // end namespace slpvectorizer
@@ -6539,8 +6572,29 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     case Instruction::Trunc:
     case Instruction::FPTrunc:
     case Instruction::BitCast: {
+      auto [PrevMaxBW, PrevMinBW] = CastMaxMinBWSizes.value_or(
+          std::make_pair(std::numeric_limits<unsigned>::min(),
+                         std::numeric_limits<unsigned>::max()));
+      if (ShuffleOrOp == Instruction::ZExt ||
+          ShuffleOrOp == Instruction::SExt) {
+        CastMaxMinBWSizes = std::make_pair(
+            std::max<unsigned>(DL->getTypeSizeInBits(VL0->getType()),
+                               PrevMaxBW),
+            std::min<unsigned>(
+                DL->getTypeSizeInBits(VL0->getOperand(0)->getType()),
+                PrevMinBW));
+      } else if (ShuffleOrOp == Instruction::Trunc) {
+        CastMaxMinBWSizes = std::make_pair(
+            std::max<unsigned>(
+                DL->getTypeSizeInBits(VL0->getOperand(0)->getType()),
+                PrevMaxBW),
+            std::min<unsigned>(DL->getTypeSizeInBits(VL0->getType()),
+                               PrevMinBW));
+        TruncNodes.insert(VectorizableTree.size());
+      }
       TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
                                    ReuseShuffleIndicies);
+
       LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n");
 
       TE->setOperandsInOrder();
@@ -8362,6 +8416,22 @@ const BoUpSLP::TreeEntry *BoUpSLP::getOperandEntry(const TreeEntry *E,
   return It->get();
 }
 
+TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const {
+  if (TE.State == TreeEntry::ScatterVectorize ||
+      TE.State == TreeEntry::StridedVectorize)
+    return TTI::CastContextHint::GatherScatter;
+  if (TE.State == TreeEntry::Vectorize && TE.getOpcode() == Instruction::Load &&
+      !TE.isAltShuffle()) {
+    if (TE.ReorderIndices.empty())
+      return TTI::CastContextHint::Normal;
+    SmallVector<int> Mask;
+    inversePermutation(TE.ReorderIndices, Mask);
+    if (ShuffleVectorInst::isReverseMask(Mask, Mask.size()))
+      return TTI::CastContextHint::Reversed;
+  }
+  return TTI::CastContextHint::None;
+}
+
 InstructionCost
 BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
                       SmallPtrSetImpl<Value *> &CheckedExtracts) {
@@ -8384,6 +8454,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
   // If we have computed a smaller type for the expression, update VecTy so
   // that the costs will be accurate.
   auto It = MinBWs.find(E);
+  Type *OrigScalarTy = ScalarTy;
   if (It != MinBWs.end()) {
     ScalarTy = IntegerType::get(F->getContext(), It->second.first);
     VecTy = FixedVectorType::get(ScalarTy, VL.size());
@@ -8441,24 +8512,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     UsedScalars.set(I);
   }
   auto GetCastContextHint = [&](Value *V) {
-    if (const TreeEntry *OpTE = getTreeEntry(V)) {
-      if (OpTE->State == TreeEntry::ScatterVectorize ||
-          OpTE->State == TreeEntry::StridedVectorize)
-        return TTI::CastContextHint::GatherScatter;
-      if (OpTE->State == TreeEntry::Vectorize &&
-          OpTE->getOpcode() == Instruction::Load && !OpTE->isAltShuffle()) {
-        if (OpTE->ReorderIndices.empty())
-          return TTI::CastContextHint::Normal;
-        SmallVector<int> Mask;
-        inversePermutation(OpTE->ReorderIndices, Mask);
-        if (ShuffleVectorInst::isReverseMask(Mask, Mask.size()))
-          return TTI::CastContextHint::Reversed;
-      }
-    } else {
-      InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI);
-      if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle())
-        return TTI::CastContextHint::GatherScatter;
-    }
+    if (const TreeEntry *OpTE = getTreeEntry(V))
+      return getCastContextHint(*OpTE);
+    InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI);
+    if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle())
+      return TTI::CastContextHint::GatherScatter;
     return TTI::CastContextHint::None;
   };
   auto GetCostDiff =
@@ -8507,8 +8565,6 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
               TTI::CastContextHint CCH = GetCastContextHint(VL0);
               VecCost += TTI->getCastInstrCost(VecOpcode, UserVecTy, VecTy, CCH,
                                                CostKind);
-              ScalarCost += Sz * TTI->getCastInstrCost(VecOpcode, UserScalarTy,
-                                                       ScalarTy, CCH, CostKind);
             }
           }
         }
@@ -8525,7 +8581,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     InstructionCost ScalarCost = 0;
     InstructionCost VecCost = 0;
     std::tie(ScalarCost, VecCost) = getGEPCosts(
-        *TTI, Ptrs, BasePtr, E->getOpcode(), CostKind, ScalarTy, VecTy);
+        *TTI, Ptrs, BasePtr, E->getOpcode(), CostKind, OrigScalarTy, VecTy);
     LLVM_DEBUG(dumpTreeCosts(E, 0, VecCost, ScalarCost,
                              "Calculated GEPs cost for Tree"));
 
@@ -8572,7 +8628,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
           NumElts = ATy->getNumElements();
         else
           NumElts = AggregateTy->getStructNumElements();
-        SrcVecTy = FixedVectorType::get(ScalarTy, NumElts);
+        SrcVecTy = FixedVectorType::get(OrigScalarTy, NumElts);
       }
       if (I->hasOneUse()) {
         Instruction *Ext = I->user_back();
@@ -8740,13 +8796,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
     }
     auto GetScalarCost = [&](unsigned Idx) -> InstructionCost {
-      // Do not count cost here if minimum bitwidth is in effect and it is just
-      // a bitcast (here it is just a noop).
-      if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast)
-        return TTI::TCC_Free;
-      auto *VI = VL0->getOpcode() == Opcode
-                     ? cast<Instruction>(UniqueValues[Idx])
-                     : nullptr;
+      auto *VI = cast<Instruction>(UniqueValues[Idx]);
       return TTI->getCastInstrCost(Opcode, VL0->getType(),
                                    VL0->getOperand(0)->getType(),
                                    TTI::getCastContextHint(VI), CostKind, VI);
@@ -8789,7 +8839,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
                                        ? CmpInst::BAD_FCMP_PREDICATE
                                        : CmpInst::BAD_ICMP_PREDICATE;
 
-      return TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy,
+      return TTI->getCmpSelInstrCost(E->getOpcode(), OrigScalarTy,
                                      Builder.getInt1Ty(), CurrentPred, CostKind,
                                      VI);
     };
@@ -8844,7 +8894,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       TTI::OperandValueInfo Op2Info =
           TTI::getOperandInfo(VI->getOperand(OpIdx));
       SmallVector<const Value *> Operands(VI->operand_values());
-      return TTI->getArithmeticInstrCost(ShuffleOrOp, ScalarTy, CostKind,
+      return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind,
                                          Op1Info, Op2Info, Operands, VI);
     };
     auto GetVectorCost = [=](InstructionCost CommonCost) {
@@ -8863,9 +8913,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
   case Instruction::Load: {
     auto GetScalarCost = [&](unsigned Idx) {
       auto *VI = cast<LoadInst>(UniqueValues[Idx]);
-      return TTI->getMemoryOpCost(Instruction::Load, ScalarTy, VI->getAlign(),
-                                  VI->getPointerAddressSpace(), CostKind,
-                                  TTI::OperandValueInfo(), VI);
+      return TTI->getMemoryOpCost(Instruction::Load, OrigScalarTy,
+                                  VI->getAlign(), VI->getPointerAddressSpace(),
+                                  CostKind, TTI::OperandValueInfo(), VI);
     };
     auto *LI0 = cast<LoadInst>(VL0);
     auto GetVectorCost = [&](InstructionCost CommonCost) {
@@ -8908,9 +8958,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     auto GetScalarCost = [=](unsigned Idx) {
       auto *VI = cast<StoreInst>(VL[Idx]);
       TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(VI->getValueOperand());
-      return TTI->getMemoryOpCost(Instruction::Store, ScalarTy, VI->getAlign(),
-                                  VI->getPointerAddressSpace(), CostKind,
-                                  OpInfo, VI);
+      return TTI->getMemoryOpCost(Instruction::Store, OrigScalarTy,
+                                  VI->getAlign(), VI->getPointerAddressSpace(),
+                                  CostKind, OpInfo, VI);
     };
     auto *BaseSI =
         cast<StoreInst>(IsReorder ? VL[E->ReorderIndices.front()] : VL0);
@@ -9772,6 +9822,44 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
     Cost -= InsertCost;
   }
 
+  // Add the cost for reduced value resize (if required).
+  if (ReductionBitWidth != 0) {
+    assert(UserIgnoreList && "Expected reduction tree.");
+    const TreeEntry &E = *VectorizableTree.front().get();
+    auto It = MinBWs.find(&E);
+    if (It != MinBWs.end() && It->second.first != ReductionBitWidth) {
+      unsigned SrcSize = It->second.first;
+      unsigned DstSize = ReductionBitWidth;
+      unsigned Opcode = Instruction::Trunc;
+      if (SrcSize < DstSize)
+        Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+      auto *SrcVecTy =
+          FixedVectorType::get(Builder.getIntNTy(SrcSize), E.getVectorFactor());
+      auto *DstVecTy =
+          FixedVectorType::get(Builder.getIntNTy(DstSize), E.getVectorFactor());
+      TTI::CastContextHint CCH = getCastContextHint(E);
+      InstructionCost CastCost;
+      switch (E.getOpcode()) {
+      case Instruction::SExt:
+      case Instruction::ZExt:
+      case Instruction::Trunc: {
+        const TreeEntry *OpTE = getOperandEntry(&E, 0);
+        CCH = getCastContextHint(*OpTE);
+        break;
+      }
+      default:
+        break;
+      }
+      CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
+                                        TTI::TCK_RecipThroughput);
+      Cost += CastCost;
+      LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
+                        << " for final resize for reduction from " << SrcVecTy
+                        << " to " << DstVecTy << "\n";
+                 dbgs() << "SLP: Current total cost = " << Cost << "\n");
+    }
+  }
+
 #ifndef NDEBUG
   SmallString<256> Str;
   {
@@ -9992,6 +10080,30 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
   // tree node for each gathered value - we have just a permutation of the
   // single vector. If we have 2 different sets, we're in situation where we
   // have a permutation of 2 input vectors.
+  // Filter out entries with larger bitwidth of elements.
+  Type *ScalarTy = VL.front()->getType();
+  unsigned BitWidth = 0;
+  if (ScalarTy->isIntegerTy()) {
+    // Check if the used TEs supposed to be resized and choose the best
+    // candidates.
+    BitWidth = DL->getTypeStoreSize(ScalarTy);
+    if (TEUseEI.UserTE->getOpcode() != Instruction::Select ||
+        TEUseEI.EdgeIdx != 0) {
+      auto UserIt = MinBWs.find(TEUseEI.UserTE);
+      if (UserIt != MinBWs.end())
+        BitWidth = UserIt->second.second;
+    }
+  }
+  auto CheckBitwidth = [&](const TreeEntry &TE) {
+    Type *ScalarTy = TE.Scalars.front()->getType();
+    if (!ScalarTy->isIntegerTy())
+      return true;
+    unsigned TEBitWidth = DL->getTypeStoreSize(ScalarTy);
+    auto UserIt = MinBWs.find(TEUseEI.UserTE);
+    if (UserIt != MinBWs.end())
+      TEBitWidth = UserIt->second.second;
+    return BitWidth == TEBitWidth;
+  };
   SmallVector<SmallPtrSet<const TreeEntry *, 4>> UsedTEs;
   DenseMap<Value *, int> UsedValuesEntry;
   for (Value *V : VL) {
@@ -10026,6 +10138,8 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
           continue;
       }
 
+      if (!CheckBitwidth(*TEPtr))
+        continue;
       // Check if the user node of the TE comes after user node of TEPtr,
       // otherwise TEPtr depends on TE.
       if ((TEInsertBlock != InsertPt->getParent() ||
@@ -10042,8 +10156,8 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
             continue;
           VTE = *It->getSecond().begin();
           // Iterate through all vectorized nodes.
-          auto *MIt = find_if(It->getSecond(), [](const TreeEntry *MTE) {
-            return MTE->State == TreeEntry::Vectorize;
+          auto *MIt = find_if(It->getSecond(), [&](const TreeEntry *MTE) {
+            return MTE->State == TreeEntry::Vectorize && CheckBitwidth(*MTE);
           });
           if (MIt == It->getSecond().end())
             continue;
@@ -10053,10 +10167,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
       Instruction &LastBundleInst = getLastInstructionInBundle(VTE);
       if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst))
         continue;
-      auto It = MinBWs.find(VTE);
-      // If vectorize node is demoted - do not match.
-      if (It != MinBWs.end() &&
-          It->second.first != DL->getTypeSizeInBits(V->getType()))
+      if (!CheckBitwidth(*VTE))
         continue;
       VToTEs.insert(VTE);
     }
@@ -12929,7 +13040,21 @@ Value *BoUpSLP::vectorizeTree(
   Builder.ClearInsertionPoint();
   InstrElementSize.clear();
 
-  return VectorizableTree[0]->VectorizedValue;
+  const TreeEntry &RootTE = *VectorizableTree.front().get();
+  Value *Vec = RootTE.VectorizedValue;
+  if (auto It = MinBWs.find(&RootTE); ReductionBitWidth != 0 &&
+                                      It != MinBWs.end() &&
+                                      ReductionBitWidth != It->second.first) {
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.SetInsertPoint(ReductionRoot->getParent(),
+                           ReductionRoot->getIterator());
+    Vec = Builder.CreateIntCast(
+        Vec,
+        VectorType::get(Builder.getIntNTy(ReductionBitWidth),
+                        cast<VectorType>(Vec->getType())->getElementCount()),
+        It->second.second);
+  }
+  return Vec;
 }
 
 void BoUpSLP::optimizeGatherSequence() {
@@ -13749,23 +13874,42 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) {
 // smaller type with a truncation. We collect the values that will be demoted
 // in ToDemote and additional roots that require investigating in Roots.
 bool BoUpSLP::collectValuesToDemote(
-    Value *V, SmallVectorImpl<Value *> &ToDemote,
+    Value *V, bool IsProfitableToDemoteRoot, unsigned &BitWidth,
+    SmallVectorImpl<Value *> &ToDemote,
     DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts,
-    SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const {
+    DenseSet<Value *> &Visited, unsigned &MaxDepthLevel,
+    bool &IsProfitableToDemote) const {
   // We can always demote constants.
-  if (isa<Constant>(V))
+  if (isa<Constant>(V)) {
+    MaxDepthLevel = 1;
     return true;
+  }
 
   // If the value is not a vectorized instruction in the expression and not used
   // by the insertelement instruction and not used in multiple vector nodes, it
   // cannot be demoted.
+  // TODO: improve handling of gathered values and others.
   auto *I = dyn_cast<Instruction>(V);
-  if (!I || !getTreeEntry(I) || MultiNodeScalars.contains(I) ||
-      !Visited.insert(I).second || all_of(I->users(), [&](User *U) {
+  if (!I || !Visited.insert(I).second || !getTreeEntry(I) ||
+      MultiNodeScalars.contains(I) || all_of(I->users(), [&](User *U) {
         return isa<InsertElementInst>(U) && !getTreeEntry(U);
       }))
     return false;
 
+  auto IsPotentiallyTruncated = [&](Value *V, unsigned &BitWidth) -> bool {
+    if (MultiNodeScalars.contains(V))
+      return false;
+    uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+    APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+    if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
+      return true;
+    auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+   ...
[truncated]

``````````



https://github.com/llvm/llvm-project/pull/84334

