[llvm] 938a734 - [SLP][NFC]Walk over entries, not single values.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 10 06:28:57 PDT 2024


Author: Alexey Bataev
Date: 2024-04-10T06:03:26-07:00
New Revision: 938a73422e0b964eba16f272acdfae1d0281772c

URL: https://github.com/llvm/llvm-project/commit/938a73422e0b964eba16f272acdfae1d0281772c
DIFF: https://github.com/llvm/llvm-project/commit/938a73422e0b964eba16f272acdfae1d0281772c.diff

LOG: [SLP][NFC]Walk over entries, not single values.

Better to walk over SLP nodes rather than over single values. Matching
a value to a node is not a 1-to-1 relation: one value may be part of
several nodes, and the compiler may pick the wrong node when trying to
map it. No such issues have been detected currently, but they may
appear in the future.
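
For illustration, here is a minimal standalone sketch of the ambiguity
this avoids (hypothetical names; the real BoUpSLP types and getTreeEntry
differ): a value-to-entry map can record only one node per value, so a
per-value lookup may return the wrong node when a scalar belongs to
several entries, while walking the entries themselves visits each node
exactly once.

// Hypothetical sketch, not the actual SLPVectorizer interface.
#include <memory>
#include <unordered_map>
#include <vector>

struct Value {};                  // stand-in for llvm::Value
struct TreeEntry {
  std::vector<Value *> Scalars;   // scalars bundled into this SLP node
  unsigned Idx = 0;               // position in VectorizableTree
};

// Value-based lookup: if V is part of two entries, the map holds only
// one of them, so this may return the "wrong" node for a given use.
TreeEntry *getEntryFor(std::unordered_map<Value *, TreeEntry *> &Map,
                       Value *V) {
  auto It = Map.find(V);
  return It == Map.end() ? nullptr : It->second;
}

// Entry-based walk: each node is visited once and carries its own
// scalars, so no value-to-node mapping is needed.
template <typename Fn>
void walkEntries(std::vector<std::unique_ptr<TreeEntry>> &Tree, Fn Visit) {
  for (std::unique_ptr<TreeEntry> &E : Tree)
    Visit(*E);
}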

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index c3dcf73b0b7626..22ef9b5fb994e3 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -2325,19 +2325,17 @@ class BoUpSLP {
   ~BoUpSLP();
 
 private:
-  /// Determine if a vectorized value \p V in can be demoted to
-  /// a smaller type with a truncation. We collect the values that will be
-  /// demoted in ToDemote and additional roots that require investigating in
-  /// Roots.
-  /// \param DemotedConsts list of Instruction/OperandIndex pairs that are
-  /// constant and to be demoted. Required to correctly identify constant nodes
-  /// to be demoted.
-  bool collectValuesToDemote(
-      Value *V, bool IsProfitableToDemoteRoot, unsigned &BitWidth,
-      SmallVectorImpl<Value *> &ToDemote,
-      DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts,
-      DenseSet<Value *> &Visited, unsigned &MaxDepthLevel,
-      bool &IsProfitableToDemote, bool IsTruncRoot) const;
+  /// Determine if a node \p E can be demoted to a smaller type with a
+  /// truncation. We collect the entries that will be demoted in ToDemote.
+  /// \param E Node for analysis.
+  /// \param ToDemote indices of the nodes to be demoted.
+  bool collectValuesToDemote(const TreeEntry &E, bool IsProfitableToDemoteRoot,
+                             unsigned &BitWidth,
+                             SmallVectorImpl<unsigned> &ToDemote,
+                             DenseSet<const TreeEntry *> &Visited,
+                             unsigned &MaxDepthLevel,
+                             bool &IsProfitableToDemote,
+                             bool IsTruncRoot) const;
 
   /// Check if the operands on the edges \p Edges of the \p UserTE allows
   /// reordering (i.e. the operands can be reordered because they have only one
@@ -14126,20 +14124,17 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) {
   return Width;
 }
 
-// Determine if a value V in a vectorizable expression Expr can be demoted to a
-// smaller type with a truncation. We collect the values that will be demoted
-// in ToDemote and additional roots that require investigating in Roots.
 bool BoUpSLP::collectValuesToDemote(
-    Value *V, bool IsProfitableToDemoteRoot, unsigned &BitWidth,
-    SmallVectorImpl<Value *> &ToDemote,
-    DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts,
-    DenseSet<Value *> &Visited, unsigned &MaxDepthLevel,
-    bool &IsProfitableToDemote, bool IsTruncRoot) const {
+    const TreeEntry &E, bool IsProfitableToDemoteRoot, unsigned &BitWidth,
+    SmallVectorImpl<unsigned> &ToDemote, DenseSet<const TreeEntry *> &Visited,
+    unsigned &MaxDepthLevel, bool &IsProfitableToDemote,
+    bool IsTruncRoot) const {
   // We can always demote constants.
-  if (isa<Constant>(V))
+  if (all_of(E.Scalars, IsaPred<Constant>))
     return true;
 
-  if (DL->getTypeSizeInBits(V->getType()) == BitWidth) {
+  unsigned OrigBitWidth = DL->getTypeSizeInBits(E.Scalars.front()->getType());
+  if (OrigBitWidth == BitWidth) {
     MaxDepthLevel = 1;
     return true;
   }
@@ -14150,7 +14145,6 @@ bool BoUpSLP::collectValuesToDemote(
   auto IsPotentiallyTruncated = [&](Value *V, unsigned &BitWidth) -> bool {
     if (MultiNodeScalars.contains(V))
       return false;
-    uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
     if (OrigBitWidth > BitWidth) {
       APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
       if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
@@ -14168,47 +14162,50 @@ bool BoUpSLP::collectValuesToDemote(
     BitWidth = std::max(BitWidth, BitWidth1);
     return BitWidth > 0 && OrigBitWidth >= (BitWidth * 2);
   };
-  auto FinalAnalysis = [&](const TreeEntry *ITE = nullptr) {
+  using namespace std::placeholders;
+  auto FinalAnalysis = [&]() {
     if (!IsProfitableToDemote)
       return false;
-    return (ITE && ITE->UserTreeIndices.size() > 1) ||
-           IsPotentiallyTruncated(V, BitWidth);
+    bool Res = all_of(
+        E.Scalars, std::bind(IsPotentiallyTruncated, _1, std::ref(BitWidth)));
+    // Gather demoted constant operands.
+    if (Res && E.State == TreeEntry::NeedToGather &&
+        all_of(E.Scalars, IsaPred<Constant>))
+      ToDemote.push_back(E.Idx);
+    return Res;
   };
   // TODO: improve handling of gathered values and others.
-  auto *I = dyn_cast<Instruction>(V);
-  const TreeEntry *ITE = I ? getTreeEntry(I) : nullptr;
-  if (!ITE || !Visited.insert(I).second || MultiNodeScalars.contains(I) ||
-      all_of(I->users(), [&](User *U) {
-        return isa<InsertElementInst>(U) && !getTreeEntry(U);
+  if (E.State == TreeEntry::NeedToGather || !Visited.insert(&E).second ||
+      any_of(E.Scalars, [&](Value *V) {
+        return all_of(V->users(), [&](User *U) {
+          return isa<InsertElementInst>(U) && !getTreeEntry(U);
+        });
       }))
     return FinalAnalysis();
 
-  if (!all_of(I->users(),
-              [=](User *U) {
-                return getTreeEntry(U) ||
-                       (UserIgnoreList && UserIgnoreList->contains(U)) ||
-                       (U->getType()->isSized() &&
-                        !U->getType()->isScalableTy() &&
-                        DL->getTypeSizeInBits(U->getType()) <= BitWidth);
-              }) &&
-      !IsPotentiallyTruncated(I, BitWidth))
+  if (any_of(E.Scalars, [&](Value *V) {
+        return !all_of(V->users(), [=](User *U) {
+          return getTreeEntry(U) ||
+                 (UserIgnoreList && UserIgnoreList->contains(U)) ||
+                 (U->getType()->isSized() && !U->getType()->isScalableTy() &&
+                  DL->getTypeSizeInBits(U->getType()) <= BitWidth);
+        }) && !IsPotentiallyTruncated(V, BitWidth);
+      }))
     return false;
 
-  unsigned Start = 0;
-  unsigned End = I->getNumOperands();
-
-  auto ProcessOperands = [&](ArrayRef<Value *> Operands, bool &NeedToExit) {
+  auto ProcessOperands = [&](ArrayRef<const TreeEntry *> Operands,
+                             bool &NeedToExit) {
     NeedToExit = false;
     unsigned InitLevel = MaxDepthLevel;
-    for (Value *IncValue : Operands) {
+    for (const TreeEntry *Op : Operands) {
       unsigned Level = InitLevel;
-      if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
-                                 ToDemote, DemotedConsts, Visited, Level,
-                                 IsProfitableToDemote, IsTruncRoot)) {
+      if (!collectValuesToDemote(*Op, IsProfitableToDemoteRoot, BitWidth,
+                                 ToDemote, Visited, Level, IsProfitableToDemote,
+                                 IsTruncRoot)) {
         if (!IsProfitableToDemote)
           return false;
         NeedToExit = true;
-        if (!FinalAnalysis(ITE))
+        if (!FinalAnalysis())
           return false;
         continue;
       }
@@ -14220,7 +14217,6 @@ bool BoUpSLP::collectValuesToDemote(
       [&](function_ref<bool(unsigned, unsigned)> Checker, bool &NeedToExit) {
         // Try all bitwidth < OrigBitWidth.
         NeedToExit = false;
-        uint32_t OrigBitWidth = DL->getTypeSizeInBits(I->getType());
         unsigned BestFailBitwidth = 0;
         for (; BitWidth < OrigBitWidth; BitWidth *= 2) {
           if (Checker(BitWidth, OrigBitWidth))
@@ -14241,18 +14237,20 @@ bool BoUpSLP::collectValuesToDemote(
         return false;
       };
   auto TryProcessInstruction =
-      [&](Instruction *I, const TreeEntry &ITE, unsigned &BitWidth,
-          ArrayRef<Value *> Operands = std::nullopt,
+      [&](unsigned &BitWidth,
+          ArrayRef<const TreeEntry *> Operands = std::nullopt,
           function_ref<bool(unsigned, unsigned)> Checker = {}) {
         if (Operands.empty()) {
           if (!IsTruncRoot)
             MaxDepthLevel = 1;
-          (void)IsPotentiallyTruncated(V, BitWidth);
+          (void)for_each(E.Scalars, std::bind(IsPotentiallyTruncated, _1,
+                                              std::ref(BitWidth)));
         } else {
           // Several vectorized uses? Check if we can truncate it, otherwise -
           // exit.
-          if (ITE.UserTreeIndices.size() > 1 &&
-              !IsPotentiallyTruncated(I, BitWidth))
+          if (E.UserTreeIndices.size() > 1 &&
+              !all_of(E.Scalars, std::bind(IsPotentiallyTruncated, _1,
+                                           std::ref(BitWidth))))
             return false;
           bool NeedToExit = false;
           if (Checker && !AttemptCheckBitwidth(Checker, NeedToExit))
@@ -14266,26 +14264,22 @@ bool BoUpSLP::collectValuesToDemote(
         }
 
         ++MaxDepthLevel;
-        // Gather demoted constant operands.
-        for (unsigned Idx : seq<unsigned>(Start, End))
-          if (isa<Constant>(I->getOperand(Idx)))
-            DemotedConsts.try_emplace(I).first->getSecond().push_back(Idx);
-        // Record the value that we can demote.
-        ToDemote.push_back(V);
+        // Record the entry that we can demote.
+        ToDemote.push_back(E.Idx);
         return IsProfitableToDemote;
       };
-  switch (I->getOpcode()) {
+  switch (E.getOpcode()) {
 
   // We can always demote truncations and extensions. Since truncations can
   // seed additional demotion, we save the truncated value.
   case Instruction::Trunc:
     if (IsProfitableToDemoteRoot)
       IsProfitableToDemote = true;
-    return TryProcessInstruction(I, *ITE, BitWidth);
+    return TryProcessInstruction(BitWidth);
   case Instruction::ZExt:
   case Instruction::SExt:
     IsProfitableToDemote = true;
-    return TryProcessInstruction(I, *ITE, BitWidth);
+    return TryProcessInstruction(BitWidth);
 
   // We can demote certain binary operations if we can demote both of their
   // operands.
@@ -14295,112 +14289,128 @@ bool BoUpSLP::collectValuesToDemote(
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor: {
-    return TryProcessInstruction(I, *ITE, BitWidth,
-                                 {I->getOperand(0), I->getOperand(1)});
+    return TryProcessInstruction(
+        BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)});
   }
   case Instruction::Shl: {
     // If we are truncating the result of this SHL, and if it's a shift of an
     // inrange amount, we can always perform a SHL in a smaller type.
     auto ShlChecker = [&](unsigned BitWidth, unsigned) {
-      KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
-      return AmtKnownBits.getMaxValue().ult(BitWidth);
+      return all_of(E.Scalars, [&](Value *V) {
+        auto *I = cast<Instruction>(V);
+        KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+        return AmtKnownBits.getMaxValue().ult(BitWidth);
+      });
     };
     return TryProcessInstruction(
-        I, *ITE, BitWidth, {I->getOperand(0), I->getOperand(1)}, ShlChecker);
+        BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)}, ShlChecker);
   }
   case Instruction::LShr: {
     // If this is a truncate of a logical shr, we can truncate it to a smaller
     // lshr iff we know that the bits we would otherwise be shifting in are
     // already zeros.
     auto LShrChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
-      KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
-      APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
-      return AmtKnownBits.getMaxValue().ult(BitWidth) &&
-             MaskedValueIsZero(I->getOperand(0), ShiftedBits,
-                               SimplifyQuery(*DL));
+      return all_of(E.Scalars, [&](Value *V) {
+        auto *I = cast<Instruction>(V);
+        KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+        APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+        return AmtKnownBits.getMaxValue().ult(BitWidth) &&
+               MaskedValueIsZero(I->getOperand(0), ShiftedBits,
+                                 SimplifyQuery(*DL));
+      });
     };
     return TryProcessInstruction(
-        I, *ITE, BitWidth, {I->getOperand(0), I->getOperand(1)}, LShrChecker);
+        BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)},
+        LShrChecker);
   }
   case Instruction::AShr: {
     // If this is a truncate of an arithmetic shr, we can truncate it to a
     // smaller ashr iff we know that all the bits from the sign bit of the
     // original type and the sign bit of the truncate type are similar.
     auto AShrChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
-      KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
-      unsigned ShiftedBits = OrigBitWidth - BitWidth;
-      return AmtKnownBits.getMaxValue().ult(BitWidth) &&
-             ShiftedBits <
-                 ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT);
+      return all_of(E.Scalars, [&](Value *V) {
+        auto *I = cast<Instruction>(V);
+        KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+        unsigned ShiftedBits = OrigBitWidth - BitWidth;
+        return AmtKnownBits.getMaxValue().ult(BitWidth) &&
+               ShiftedBits < ComputeNumSignBits(I->getOperand(0), *DL, 0, AC,
+                                                nullptr, DT);
+      });
     };
     return TryProcessInstruction(
-        I, *ITE, BitWidth, {I->getOperand(0), I->getOperand(1)}, AShrChecker);
+        BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)},
+        AShrChecker);
   }
   case Instruction::UDiv:
   case Instruction::URem: {
     // UDiv and URem can be truncated if all the truncated bits are zero.
     auto Checker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
       assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!");
-      APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
-      return MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)) &&
-             MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL));
+      return all_of(E.Scalars, [&](Value *V) {
+        auto *I = cast<Instruction>(V);
+        APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+        return MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)) &&
+               MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL));
+      });
     };
-    return TryProcessInstruction(I, *ITE, BitWidth,
-                                 {I->getOperand(0), I->getOperand(1)}, Checker);
+    return TryProcessInstruction(
+        BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)}, Checker);
   }
 
   // We can demote selects if we can demote their true and false values.
   case Instruction::Select: {
-    Start = 1;
-    auto *SI = cast<SelectInst>(I);
-    return TryProcessInstruction(I, *ITE, BitWidth,
-                                 {SI->getTrueValue(), SI->getFalseValue()});
+    return TryProcessInstruction(
+        BitWidth, {getOperandEntry(&E, 1), getOperandEntry(&E, 2)});
   }
 
   // We can demote phis if we can demote all their incoming operands. Note that
   // we don't need to worry about cycles since we ensure single use above.
   case Instruction::PHI: {
-    PHINode *PN = cast<PHINode>(I);
-    SmallVector<Value *> Ops(PN->incoming_values().begin(),
-                             PN->incoming_values().end());
-    return TryProcessInstruction(I, *ITE, BitWidth, Ops);
+    const unsigned NumOps = E.getNumOperands();
+    SmallVector<const TreeEntry *> Ops(NumOps);
+    transform(seq<unsigned>(0, NumOps), Ops.begin(),
+              std::bind(&BoUpSLP::getOperandEntry, this, &E, _1));
+
+    return TryProcessInstruction(BitWidth, Ops);
   }
 
   case Instruction::Call: {
-    auto *IC = dyn_cast<IntrinsicInst>(I);
+    auto *IC = dyn_cast<IntrinsicInst>(E.getMainOp());
     if (!IC)
       break;
     Intrinsic::ID ID = getVectorIntrinsicIDForCall(IC, TLI);
     if (ID != Intrinsic::abs && ID != Intrinsic::smin &&
         ID != Intrinsic::smax && ID != Intrinsic::umin && ID != Intrinsic::umax)
       break;
-    SmallVector<Value *> Operands(1, I->getOperand(0));
+    SmallVector<const TreeEntry *, 2> Operands(1, getOperandEntry(&E, 0));
     function_ref<bool(unsigned, unsigned)> CallChecker;
     auto CompChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
       assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!");
-      if (ID == Intrinsic::umin || ID == Intrinsic::umax) {
-        APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
-        return MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)) &&
-               MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL));
-      }
-      assert((ID == Intrinsic::smin || ID == Intrinsic::smax) &&
-             "Expected min/max intrinsics only.");
-      unsigned SignBits = OrigBitWidth - BitWidth;
-      return SignBits <= ComputeNumSignBits(I->getOperand(0), *DL, 0, AC,
-                                            nullptr, DT) &&
-             SignBits <=
-                 ComputeNumSignBits(I->getOperand(1), *DL, 0, AC, nullptr, DT);
+      return all_of(E.Scalars, [&](Value *V) {
+        auto *I = cast<Instruction>(V);
+        if (ID == Intrinsic::umin || ID == Intrinsic::umax) {
+          APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+          return MaskedValueIsZero(I->getOperand(0), Mask,
+                                   SimplifyQuery(*DL)) &&
+                 MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL));
+        }
+        assert((ID == Intrinsic::smin || ID == Intrinsic::smax) &&
+               "Expected min/max intrinsics only.");
+        unsigned SignBits = OrigBitWidth - BitWidth;
+        return SignBits <= ComputeNumSignBits(I->getOperand(0), *DL, 0, AC,
+                                              nullptr, DT) &&
+               SignBits <= ComputeNumSignBits(I->getOperand(1), *DL, 0, AC,
+                                              nullptr, DT);
+      });
     };
-    End = 1;
     if (ID != Intrinsic::abs) {
-      Operands.push_back(I->getOperand(1));
-      End = 2;
+      Operands.push_back(getOperandEntry(&E, 1));
       CallChecker = CompChecker;
     }
     InstructionCost BestCost =
         std::numeric_limits<InstructionCost::CostType>::max();
     unsigned BestBitWidth = BitWidth;
-    unsigned VF = ITE->Scalars.size();
+    unsigned VF = E.Scalars.size();
     // Choose the best bitwidth based on cost estimations.
     auto Checker = [&](unsigned BitWidth, unsigned) {
       unsigned MinBW = PowerOf2Ceil(BitWidth);
@@ -14419,7 +14429,7 @@ bool BoUpSLP::collectValuesToDemote(
     [[maybe_unused]] bool NeedToExit;
     (void)AttemptCheckBitwidth(Checker, NeedToExit);
     BitWidth = BestBitWidth;
-    return TryProcessInstruction(I, *ITE, BitWidth, Operands, CallChecker);
+    return TryProcessInstruction(BitWidth, Operands, CallChecker);
   }
 
   // Otherwise, conservatively give up.
@@ -14473,26 +14483,27 @@ void BoUpSLP::computeMinimumValueSizes() {
     ++NodeIdx;
   }
 
-  // Analyzed in reduction already and not profitable - exit.
+  // Analyzed the reduction already and not profitable - exit.
   if (AnalyzedMinBWVals.contains(VectorizableTree[NodeIdx]->Scalars.front()))
     return;
 
-  SmallVector<Value *> ToDemote;
-  DenseMap<Instruction *, SmallVector<unsigned>> DemotedConsts;
-  auto ComputeMaxBitWidth = [&](ArrayRef<Value *> TreeRoot, unsigned VF,
-                                bool IsTopRoot, bool IsProfitableToDemoteRoot,
-                                unsigned Opcode, unsigned Limit,
-                                bool IsTruncRoot, bool IsSignedCmp) {
+  SmallVector<unsigned> ToDemote;
+  auto ComputeMaxBitWidth = [&](const TreeEntry &E, bool IsTopRoot,
+                                bool IsProfitableToDemoteRoot, unsigned Opcode,
+                                unsigned Limit, bool IsTruncRoot,
+                                bool IsSignedCmp) {
     ToDemote.clear();
-    auto *TreeRootIT = dyn_cast<IntegerType>(TreeRoot[0]->getType());
+    unsigned VF = E.getVectorFactor();
+    auto *TreeRootIT = dyn_cast<IntegerType>(E.Scalars.front()->getType());
     if (!TreeRootIT || !Opcode)
       return 0u;
 
-    if (AnalyzedMinBWVals.contains(TreeRoot.front()))
+    if (any_of(E.Scalars,
+               [&](Value *V) { return AnalyzedMinBWVals.contains(V); }))
       return 0u;
 
-    unsigned NumParts = TTI->getNumberOfParts(
-        FixedVectorType::get(TreeRoot.front()->getType(), VF));
+    unsigned NumParts =
+        TTI->getNumberOfParts(FixedVectorType::get(TreeRootIT, VF));
 
     // The maximum bit width required to represent all the values that can be
     // demoted without loss of precision. It would be safe to truncate the roots
@@ -14505,14 +14516,14 @@ void BoUpSLP::computeMinimumValueSizes() {
     // True.
     // Determine if the sign bit of all the roots is known to be zero. If not,
     // IsKnownPositive is set to False.
-    bool IsKnownPositive = !IsSignedCmp && all_of(TreeRoot, [&](Value *R) {
+    bool IsKnownPositive = !IsSignedCmp && all_of(E.Scalars, [&](Value *R) {
       KnownBits Known = computeKnownBits(R, *DL);
       return Known.isNonNegative();
     });
 
     // We first check if all the bits of the roots are demanded. If they're not,
     // we can truncate the roots to this narrower type.
-    for (auto *Root : TreeRoot) {
+    for (Value *Root : E.Scalars) {
       unsigned NumSignBits = ComputeNumSignBits(Root, *DL, 0, AC, nullptr, DT);
       TypeSize NumTypeBits = DL->getTypeSizeInBits(Root->getType());
       unsigned BitWidth1 = NumTypeBits - NumSignBits;
@@ -14557,23 +14568,22 @@ void BoUpSLP::computeMinimumValueSizes() {
     // Conservatively determine if we can actually truncate the roots of the
     // expression. Collect the values that can be demoted in ToDemote and
     // additional roots that require investigating in Roots.
-    for (auto *Root : TreeRoot) {
-      DenseSet<Value *> Visited;
-      unsigned MaxDepthLevel = IsTruncRoot ? Limit : 1;
-      bool NeedToDemote = IsProfitableToDemote;
-
-      if (!collectValuesToDemote(Root, IsProfitableToDemoteRoot, MaxBitWidth,
-                                 ToDemote, DemotedConsts, Visited,
-                                 MaxDepthLevel, NeedToDemote, IsTruncRoot) ||
-          (MaxDepthLevel <= Limit &&
-           !(((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
-              (!IsTopRoot || !(IsStoreOrInsertElt || UserIgnoreList) ||
-               DL->getTypeSizeInBits(Root->getType()) /
-                       DL->getTypeSizeInBits(
-                           cast<Instruction>(Root)->getOperand(0)->getType()) >
-                   2)))))
-        return 0u;
-    }
+    DenseSet<const TreeEntry *> Visited;
+    unsigned MaxDepthLevel = IsTruncRoot ? Limit : 1;
+    bool NeedToDemote = IsProfitableToDemote;
+
+    if (!collectValuesToDemote(E, IsProfitableToDemoteRoot, MaxBitWidth,
+                               ToDemote, Visited, MaxDepthLevel, NeedToDemote,
+                               IsTruncRoot) ||
+        (MaxDepthLevel <= Limit &&
+         !(((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
+            (!IsTopRoot || !(IsStoreOrInsertElt || UserIgnoreList) ||
+             DL->getTypeSizeInBits(TreeRootIT) /
+                     DL->getTypeSizeInBits(cast<Instruction>(E.Scalars.front())
+                                               ->getOperand(0)
+                                               ->getType()) >
+                 2)))))
+      return 0u;
     // Round MaxBitWidth up to the next power-of-two.
     MaxBitWidth = bit_ceil(MaxBitWidth);
 
@@ -14624,8 +14634,8 @@ void BoUpSLP::computeMinimumValueSizes() {
                 VectorizableTree.front()->Scalars.front()->getType()))
       Limit = 3;
     unsigned MaxBitWidth = ComputeMaxBitWidth(
-        TreeRoot, VectorizableTree[NodeIdx]->getVectorFactor(), IsTopRoot,
-        IsProfitableToDemoteRoot, Opcode, Limit, IsTruncRoot, IsSignedCmp);
+        *VectorizableTree[NodeIdx].get(), IsTopRoot, IsProfitableToDemoteRoot,
+        Opcode, Limit, IsTruncRoot, IsSignedCmp);
     if (ReductionBitWidth != 0 && (IsTopRoot || !RootDemotes.empty())) {
       if (MaxBitWidth != 0 && ReductionBitWidth < MaxBitWidth)
         ReductionBitWidth = bit_ceil(MaxBitWidth);
@@ -14634,13 +14644,15 @@ void BoUpSLP::computeMinimumValueSizes() {
     }
 
     for (unsigned Idx : RootDemotes) {
-      Value *V = VectorizableTree[Idx]->Scalars.front();
-      uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
-      if (OrigBitWidth > MaxBitWidth) {
-      APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, MaxBitWidth);
-      if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
-        ToDemote.push_back(V);
-      }
+      if (all_of(VectorizableTree[Idx]->Scalars, [&](Value *V) {
+            uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+            if (OrigBitWidth > MaxBitWidth) {
+              APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, MaxBitWidth);
+              return MaskedValueIsZero(V, Mask, SimplifyQuery(*DL));
+            }
+            return false;
+          }))
+        ToDemote.push_back(Idx);
     }
     RootDemotes.clear();
     IsTopRoot = false;
@@ -14687,9 +14699,8 @@ void BoUpSLP::computeMinimumValueSizes() {
 
     // Finally, map the values we can demote to the maximum bit width we
     // computed.
-    for (Value *Scalar : ToDemote) {
-      TreeEntry *TE = getTreeEntry(Scalar);
-      assert(TE && "Expected vectorized scalar.");
+    for (unsigned Idx : ToDemote) {
+      TreeEntry *TE = VectorizableTree[Idx].get();
       if (MinBWs.contains(TE))
         continue;
       bool IsSigned = TE->getOpcode() == Instruction::SExt ||
@@ -14697,22 +14708,6 @@ void BoUpSLP::computeMinimumValueSizes() {
                         return !isKnownNonNegative(R, SimplifyQuery(*DL));
                       });
       MinBWs.try_emplace(TE, MaxBitWidth, IsSigned);
-      const auto *I = cast<Instruction>(Scalar);
-      auto DCIt = DemotedConsts.find(I);
-      if (DCIt != DemotedConsts.end()) {
-        for (unsigned Idx : DCIt->getSecond()) {
-          // Check that all instructions operands are demoted.
-          const TreeEntry *CTE = getOperandEntry(TE, Idx);
-          if (all_of(TE->Scalars,
-                     [&](Value *V) {
-                       auto SIt = DemotedConsts.find(cast<Instruction>(V));
-                       return SIt != DemotedConsts.end() &&
-                              is_contained(SIt->getSecond(), Idx);
-                     }) ||
-              all_of(CTE->Scalars, IsaPred<Constant>))
-            MinBWs.try_emplace(CTE, MaxBitWidth, IsSigned);
-        }
-      }
     }
   }
 }
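
A side note on the std::bind idiom introduced above: binding
std::placeholders::_1 for the element and std::ref for BitWidth lets a
two-argument predicate that widens BitWidth by reference be passed to
all_of/for_each. A generic standalone sketch of the pattern (plain
standard library, not the SLPVectorizer code itself):

#include <algorithm>
#include <functional>
#include <vector>

int main() {
  using namespace std::placeholders;
  std::vector<unsigned> Vals = {3, 5, 9};
  unsigned BitWidth = 4;
  // Two-argument predicate that may widen BitWidth as a side effect,
  // in the spirit of IsPotentiallyTruncated in the patch.
  auto Fits = [](unsigned V, unsigned &BW) {
    while (BW < 32 && V >= (1u << BW))
      ++BW; // widen until V fits
    return true;
  };
  // std::ref keeps BitWidth a real reference through the bind; _1
  // receives each element of Vals.
  bool AllFit = std::all_of(Vals.begin(), Vals.end(),
                            std::bind(Fits, _1, std::ref(BitWidth)));
  return AllFit && BitWidth == 4 ? 0 : 1; // 9 < 16, so BitWidth stays 4
}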
