[llvm] [SLP]Improved reduction cost/codegen (PR #118293)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 24 13:05:33 PST 2024


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/118293

>From 2f17bfb2f9814b2a40fa1ced3947d5348e4a9d96 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Mon, 2 Dec 2024 13:37:07 +0000
Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20in?=
 =?UTF-8?q?itial=20version?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
---
 .../llvm/Analysis/TargetTransformInfo.h       |   8 +
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   1 +
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |  16 +
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   4 +
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 322 ++++++++++++++++--
 5 files changed, 319 insertions(+), 32 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 985ca1532e0149..f2f0e56a3f2014 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1584,6 +1584,10 @@ class TargetTransformInfo {
   /// split during legalization. Zero is returned when the answer is unknown.
   unsigned getNumberOfParts(Type *Tp) const;
 
+  /// \return true if \p Tp represents a type fully occupying a whole register,
+  /// false otherwise.
+  bool isFullSingleRegisterType(Type *Tp) const;
+
   /// \returns The cost of the address computation. For most targets this can be
   /// merged into the instruction indexing mode. Some targets might want to
   /// distinguish between address computation for memory operations on vector
@@ -2196,6 +2200,7 @@ class TargetTransformInfo::Concept {
                                            ArrayRef<Type *> Tys,
                                            TTI::TargetCostKind CostKind) = 0;
   virtual unsigned getNumberOfParts(Type *Tp) = 0;
+  virtual bool isFullSingleRegisterType(Type *Tp) const = 0;
   virtual InstructionCost
   getAddressComputationCost(Type *Ty, ScalarEvolution *SE, const SCEV *Ptr) = 0;
   virtual InstructionCost
@@ -2930,6 +2935,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getNumberOfParts(Type *Tp) override {
     return Impl.getNumberOfParts(Tp);
   }
+  bool isFullSingleRegisterType(Type *Tp) const override {
+    return Impl.isFullSingleRegisterType(Tp);
+  }
   InstructionCost getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
                                             const SCEV *Ptr) override {
     return Impl.getAddressComputationCost(Ty, SE, Ptr);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 38aba183f6a173..ce6a96ea317ba7 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -833,6 +833,7 @@ class TargetTransformInfoImplBase {
 
   // Assume that we have a register of the right size for the type.
   unsigned getNumberOfParts(Type *Tp) const { return 1; }
+  bool isFullSingleRegisterType(Type *Tp) const { return false; }
 
   InstructionCost getAddressComputationCost(Type *Tp, ScalarEvolution *,
                                             const SCEV *) const {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 98cbb4886642bf..9e7ce48f901dc5 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2612,6 +2612,22 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return *LT.first.getValue();
   }
 
+  bool isFullSingleRegisterType(Type *Tp) const {
+    std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
+    if (!LT.first.isValid() || LT.first > 1)
+      return false;
+
+    if (auto *FTp = dyn_cast<FixedVectorType>(Tp);
+        Tp && LT.second.isFixedLengthVector()) {
+      // Check if the <n x i1> vector fits fully into the largest legal integer.
+      if (unsigned VF = LT.second.getVectorNumElements();
+          LT.second.getVectorElementType() == MVT::i1)
+        return DL.isLegalInteger(VF) && !DL.isLegalInteger(VF * 2);
+      return FTp == EVT(LT.second).getTypeForEVT(Tp->getContext());
+    }
+    return false;
+  }
+
   InstructionCost getAddressComputationCost(Type *Ty, ScalarEvolution *,
                                             const SCEV *) {
     return 0;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1fb2b9836de0cc..f7ad9ed905e3a1 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1171,6 +1171,10 @@ unsigned TargetTransformInfo::getNumberOfParts(Type *Tp) const {
   return TTIImpl->getNumberOfParts(Tp);
 }
 
+bool TargetTransformInfo::isFullSingleRegisterType(Type *Tp) const {
+  return TTIImpl->isFullSingleRegisterType(Tp);
+}
+
 InstructionCost
 TargetTransformInfo::getAddressComputationCost(Type *Tp, ScalarEvolution *SE,
                                                const SCEV *Ptr) const {
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 7723442bc0fb6e..5df21b77643746 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -12080,7 +12080,11 @@ bool BoUpSLP::isTreeNotExtendable() const {
     TreeEntry &E = *VectorizableTree[Idx];
     if (!E.isGather())
       continue;
-    if (E.getOpcode() && E.getOpcode() != Instruction::Load)
+    if ((E.getOpcode() && E.getOpcode() != Instruction::Load) ||
+        (!E.getOpcode() &&
+         all_of(E.Scalars, IsaPred<ExtractElementInst, LoadInst>)) ||
+        (isa<ExtractElementInst>(E.Scalars.front()) &&
+         getSameOpcode(ArrayRef(E.Scalars).drop_front(), *TLI).getOpcode()))
       return false;
     if (isSplat(E.Scalars) || allConstant(E.Scalars))
       continue;
@@ -19174,6 +19178,9 @@ class HorizontalReduction {
   /// Checks if the optimization of original scalar identity operations on
   /// matched horizontal reductions is enabled and allowed.
   bool IsSupportedHorRdxIdentityOp = false;
+  /// Contains the vector values for the reduction, including their scale
+  /// factor and signedness.
+  SmallVector<std::tuple<Value *, unsigned, bool>> VectorValuesAndScales;
 
   static bool isCmpSelMinMax(Instruction *I) {
     return match(I, m_Select(m_Cmp(), m_Value(), m_Value())) &&
@@ -19225,17 +19232,22 @@ class HorizontalReduction {
   static Value *createOp(IRBuilderBase &Builder, RecurKind Kind, Value *LHS,
                          Value *RHS, const Twine &Name, bool UseSelect) {
     unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+    Type *OpTy = LHS->getType();
+    assert(OpTy == RHS->getType() && "Expected LHS and RHS of same type");
     switch (Kind) {
     case RecurKind::Or:
-      if (UseSelect &&
-          LHS->getType() == CmpInst::makeCmpResultType(LHS->getType()))
-        return Builder.CreateSelect(LHS, Builder.getTrue(), RHS, Name);
+      if (UseSelect && OpTy == CmpInst::makeCmpResultType(OpTy))
+        return Builder.CreateSelect(
+            LHS,
+            ConstantInt::getAllOnesValue(CmpInst::makeCmpResultType(OpTy)),
+            RHS, Name);
       return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
                                  Name);
     case RecurKind::And:
-      if (UseSelect &&
-          LHS->getType() == CmpInst::makeCmpResultType(LHS->getType()))
-        return Builder.CreateSelect(LHS, RHS, Builder.getFalse(), Name);
+      if (UseSelect && OpTy == CmpInst::makeCmpResultType(OpTy))
+        return Builder.CreateSelect(
+            LHS, RHS,
+            ConstantInt::getNullValue(CmpInst::makeCmpResultType(OpTy)), Name);
       return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
                                  Name);
     case RecurKind::Add:
@@ -20108,12 +20120,11 @@ class HorizontalReduction {
                                          SameValuesCounter, TrackedToOrig);
         }
 
-        Value *ReducedSubTree;
         Type *ScalarTy = VL.front()->getType();
         if (isa<FixedVectorType>(ScalarTy)) {
           assert(SLPReVec && "FixedVectorType is not expected.");
           unsigned ScalarTyNumElements = getNumElements(ScalarTy);
-          ReducedSubTree = PoisonValue::get(FixedVectorType::get(
+          Value *ReducedSubTree = PoisonValue::get(getWidenedType(
               VectorizedRoot->getType()->getScalarType(), ScalarTyNumElements));
           for (unsigned I : seq<unsigned>(ScalarTyNumElements)) {
             // Do reduction for each lane.
@@ -20131,30 +20142,32 @@ class HorizontalReduction {
             SmallVector<int, 16> Mask =
                 createStrideMask(I, ScalarTyNumElements, VL.size());
             Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
-            ReducedSubTree = Builder.CreateInsertElement(
-                ReducedSubTree,
-                emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
+            Value *Val =
+                createSingleOp(Builder, *TTI, Lane,
+                               OptReusedScalars && SameScaleFactor
+                                   ? SameValuesCounter.front().second
+                                   : 1,
+                               Lane->getType()->getScalarType() !=
+                                       VL.front()->getType()->getScalarType()
+                                   ? V.isSignedMinBitwidthRootNode()
+                                   : true, RdxRootInst->getType());
+            ReducedSubTree =
+                Builder.CreateInsertElement(ReducedSubTree, Val, I);
           }
+          VectorizedTree = GetNewVectorizedTree(VectorizedTree, ReducedSubTree);
         } else {
-          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
-                                         RdxRootInst->getType());
+          Type *VecTy = VectorizedRoot->getType();
+          Type *RedScalarTy = VecTy->getScalarType();
+          VectorValuesAndScales.emplace_back(
+              VectorizedRoot,
+              OptReusedScalars && SameScaleFactor
+                  ? SameValuesCounter.front().second
+                  : 1,
+              RedScalarTy != ScalarTy->getScalarType()
+                  ? V.isSignedMinBitwidthRootNode()
+                  : true);
         }
-        if (ReducedSubTree->getType() != VL.front()->getType()) {
-          assert(ReducedSubTree->getType() != VL.front()->getType() &&
-                 "Expected different reduction type.");
-          ReducedSubTree =
-              Builder.CreateIntCast(ReducedSubTree, VL.front()->getType(),
-                                    V.isSignedMinBitwidthRootNode());
-        }
-
-        // Improved analysis for add/fadd/xor reductions with same scale factor
-        // for all operands of reductions. We can emit scalar ops for them
-        // instead.
-        if (OptReusedScalars && SameScaleFactor)
-          ReducedSubTree = emitScaleForReusedOps(
-              ReducedSubTree, Builder, SameValuesCounter.front().second);
 
-        VectorizedTree = GetNewVectorizedTree(VectorizedTree, ReducedSubTree);
         // Count vectorized reduced values to exclude them from final reduction.
         for (Value *RdxVal : VL) {
           Value *OrigV = TrackedToOrig.at(RdxVal);
@@ -20183,6 +20196,10 @@ class HorizontalReduction {
         continue;
       }
     }
+    if (!VectorValuesAndScales.empty())
+      VectorizedTree = GetNewVectorizedTree(
+          VectorizedTree,
+          emitReduction(Builder, *TTI, ReductionRoot->getType()));
     if (VectorizedTree) {
       // Reorder operands of bool logical op in the natural order to avoid
       // possible problem with poison propagation. If not possible to reorder
@@ -20317,6 +20334,28 @@ class HorizontalReduction {
   }
 
 private:
+  /// Checks if the given type \p Ty is a vector type that does not occupy the
+  /// whole vector register or is expensive for extraction.
+  static bool isNotFullVectorType(const TargetTransformInfo &TTI, Type *Ty) {
+    return TTI.getNumberOfParts(Ty) == 1 && !TTI.isFullSingleRegisterType(Ty);
+  }
+
+  /// Creates the reduction from the given \p Vec vector value with the given
+  /// scale \p Scale and signedness \p IsSigned.
+  Value *createSingleOp(IRBuilderBase &Builder, const TargetTransformInfo &TTI,
+                        Value *Vec, unsigned Scale, bool IsSigned,
+                        Type *DestTy) {
+    Value *Rdx = emitReduction(Vec, Builder, &TTI, DestTy);
+    if (Rdx->getType() != DestTy->getScalarType())
+      Rdx = Builder.CreateIntCast(Rdx, DestTy, IsSigned);
+    // Improved analysis for add/fadd/xor reductions with same scale
+    // factor for all operands of reductions. We can emit scalar ops for
+    // them instead.
+    if (Scale > 1)
+      Rdx = emitScaleForReusedOps(Rdx, Builder, Scale);
+    return Rdx;
+  }
+
   /// Calculate the cost of a reduction.
   InstructionCost getReductionCost(TargetTransformInfo *TTI,
                                    ArrayRef<Value *> ReducedVals,
@@ -20359,6 +20398,22 @@ class HorizontalReduction {
       }
       return Cost;
     };
+    // Require the cost of a reduction operation if:
+    // 1. This type is not a full register type and no other vector with the
+    // same type exists in the storage (first vector with a small type).
+    // 2. The storage does not hold any vector with full register use (first
+    // vector with full register use).
+    bool DoesRequireReductionOp =
+        !AllConsts &&
+        (VectorValuesAndScales.empty() ||
+         (isNotFullVectorType(*TTI, VectorTy) &&
+          none_of(VectorValuesAndScales,
+                  [&](const auto &P) {
+                    return std::get<0>(P)->getType() == VectorTy;
+                  })) ||
+         all_of(VectorValuesAndScales, [&](const auto &P) {
+           return isNotFullVectorType(*TTI, std::get<0>(P)->getType());
+         }));
     switch (RdxKind) {
     case RecurKind::Add:
     case RecurKind::Mul:
@@ -20382,7 +20437,7 @@ class HorizontalReduction {
           VectorCost += TTI->getScalarizationOverhead(
               VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
               /*Extract*/ false, TTI::TCK_RecipThroughput);
-        } else {
+        } else if (DoesRequireReductionOp) {
           Type *RedTy = VectorTy->getElementType();
           auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
               std::make_pair(RedTy, true));
@@ -20394,6 +20449,14 @@ class HorizontalReduction {
                 RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
                 FMF, CostKind);
           }
+        } else {
+          unsigned NumParts = TTI->getNumberOfParts(VectorTy);
+          unsigned RegVF = getPartNumElems(getNumElements(VectorTy), NumParts);
+          VectorCost +=
+              NumParts * TTI->getArithmeticInstrCost(
+                             RdxOpcode,
+                             getWidenedType(VectorTy->getScalarType(), RegVF),
+                             CostKind);
         }
       }
       ScalarCost = EvaluateScalarCost([&]() {
@@ -20410,8 +20473,19 @@ class HorizontalReduction {
     case RecurKind::UMax:
     case RecurKind::UMin: {
       Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-      if (!AllConsts)
-        VectorCost = TTI->getMinMaxReductionCost(Id, VectorTy, FMF, CostKind);
+      if (!AllConsts) {
+        if (DoesRequireReductionOp) {
+          VectorCost = TTI->getMinMaxReductionCost(Id, VectorTy, FMF, CostKind);
+        } else {
+          // Check if the previous reduction already exists and account for it
+          // as a series of operations + a single reduction.
+          unsigned NumParts = TTI->getNumberOfParts(VectorTy);
+          unsigned RegVF = getPartNumElems(getNumElements(VectorTy), NumParts);
+          auto *RegVecTy = getWidenedType(VectorTy->getScalarType(), RegVF);
+          IntrinsicCostAttributes ICA(Id, RegVecTy, {RegVecTy, RegVecTy}, FMF);
+          VectorCost += NumParts * TTI->getIntrinsicInstrCost(ICA, CostKind);
+        }
+      }
       ScalarCost = EvaluateScalarCost([&]() {
         IntrinsicCostAttributes ICA(Id, ScalarTy, {ScalarTy, ScalarTy}, FMF);
         return TTI->getIntrinsicInstrCost(ICA, CostKind);
@@ -20428,6 +20502,190 @@ class HorizontalReduction {
     return VectorCost - ScalarCost;
   }
 
+  /// Splits the values stored in VectorValuesAndScales into registers/free
+  /// sub-registers, combines them with the given reduction operation as a
+  /// vector operation and then performs a single (small enough) reduction.
+  Value *emitReduction(IRBuilderBase &Builder, const TargetTransformInfo &TTI,
+                       Type *DestTy) {
+    Value *ReducedSubTree = nullptr;
+    // Creates a reduction and combines it with the previous reduction.
+    auto CreateSingleOp = [&](Value *Vec, unsigned Scale, bool IsSigned) {
+      Value *Rdx = createSingleOp(Builder, TTI, Vec, Scale, IsSigned, DestTy);
+      if (ReducedSubTree)
+        ReducedSubTree = createOp(Builder, RdxKind, ReducedSubTree, Rdx,
+                                  "op.rdx", ReductionOps);
+      else
+        ReducedSubTree = Rdx;
+    };
+    if (VectorValuesAndScales.size() == 1) {
+      const auto &[Vec, Scale, IsSigned] = VectorValuesAndScales.front();
+      CreateSingleOp(Vec, Scale, IsSigned);
+      return ReducedSubTree;
+    }
+    // Splits a multi-register vector value into per-register values.
+    auto SplitVector = [&](Value *Vec) {
+      auto *ScalarTy = cast<VectorType>(Vec->getType())->getElementType();
+      unsigned Sz = getNumElements(Vec->getType());
+      unsigned NumParts = TTI.getNumberOfParts(Vec->getType());
+      if (NumParts <= 1 || NumParts >= Sz ||
+          isNotFullVectorType(TTI, Vec->getType()))
+        return SmallVector<Value *>(1, Vec);
+      unsigned RegSize = getPartNumElems(Sz, NumParts);
+      auto *DstTy = getWidenedType(ScalarTy, RegSize);
+      SmallVector<Value *> Regs(NumParts);
+      for (unsigned Part : seq<unsigned>(NumParts))
+        Regs[Part] = Builder.CreateExtractVector(
+            DstTy, Vec, Builder.getInt64(Part * RegSize));
+      return Regs;
+    };
+    SmallMapVector<Type *, Value *, 4> VecOps;
+    // Scales Vec by the given Cnt scale factor and then combines it with the
+    // previous value of VecOp as a vector operation.
+    auto CreateVecOp = [&](Value *Vec, unsigned Cnt) {
+      Type *ScalarTy = cast<VectorType>(Vec->getType())->getElementType();
+      // Scale Vec using given Cnt scale factor.
+      if (Cnt > 1) {
+        ElementCount EC = cast<VectorType>(Vec->getType())->getElementCount();
+        switch (RdxKind) {
+        case RecurKind::Add: {
+          if (ScalarTy == Builder.getInt1Ty() && ScalarTy != DestTy) {
+            unsigned VF = getNumElements(Vec->getType());
+            LLVM_DEBUG(dbgs() << "SLP: ctpop " << Cnt << "of " << Vec
+                              << ". (HorRdx)\n");
+            SmallVector<int> Mask(Cnt * VF, PoisonMaskElem);
+            for (unsigned I : seq<unsigned>(Cnt))
+              std::iota(std::next(Mask.begin(), VF * I),
+                        std::next(Mask.begin(), VF * (I + 1)), 0);
+            ++NumVectorInstructions;
+            Vec = Builder.CreateShuffleVector(Vec, Mask);
+            break;
+          }
+          // res = mul vv, n
+          Value *Scale =
+              ConstantVector::getSplat(EC, ConstantInt::get(ScalarTy, Cnt));
+          LLVM_DEBUG(dbgs() << "SLP: Add (to-mul) " << Cnt << "of " << Vec
+                            << ". (HorRdx)\n");
+          ++NumVectorInstructions;
+          Vec = Builder.CreateMul(Vec, Scale);
+          break;
+        }
+        case RecurKind::Xor: {
+          // res = n % 2 ? 0 : vv
+          LLVM_DEBUG(dbgs()
+                     << "SLP: Xor " << Cnt << "of " << Vec << ". (HorRdx)\n");
+          if (Cnt % 2 == 0)
+            Vec = Constant::getNullValue(Vec->getType());
+          break;
+        }
+        case RecurKind::FAdd: {
+          // res = fmul v, n
+          Value *Scale =
+              ConstantVector::getSplat(EC, ConstantFP::get(ScalarTy, Cnt));
+          LLVM_DEBUG(dbgs() << "SLP: FAdd (to-fmul) " << Cnt << "of " << Vec
+                            << ". (HorRdx)\n");
+          ++NumVectorInstructions;
+          Vec = Builder.CreateFMul(Vec, Scale);
+          break;
+        }
+        case RecurKind::And:
+        case RecurKind::Or:
+        case RecurKind::SMax:
+        case RecurKind::SMin:
+        case RecurKind::UMax:
+        case RecurKind::UMin:
+        case RecurKind::FMax:
+        case RecurKind::FMin:
+        case RecurKind::FMaximum:
+        case RecurKind::FMinimum:
+          // res = vv
+          break;
+        case RecurKind::Mul:
+        case RecurKind::FMul:
+        case RecurKind::FMulAdd:
+        case RecurKind::IAnyOf:
+        case RecurKind::FAnyOf:
+        case RecurKind::None:
+          llvm_unreachable("Unexpected reduction kind for repeated scalar.");
+        }
+      }
+      // Combine Vec with the previous VecOp.
+      Value *&VecOp = VecOps[Vec->getType()];
+      if (!VecOp) {
+        VecOp = Vec;
+      } else {
+        ++NumVectorInstructions;
+        if (ScalarTy == Builder.getInt1Ty() && ScalarTy != DestTy) {
+          // Handle ctpop.
+          SmallVector<int> Mask(getNumElements(VecOp->getType()) +
+                                    getNumElements(Vec->getType()),
+                                PoisonMaskElem);
+          std::iota(Mask.begin(), Mask.end(), 0);
+          VecOp = Builder.CreateShuffleVector(VecOp, Vec, Mask, "rdx.op");
+          return;
+        }
+        VecOp = createOp(Builder, RdxKind, VecOp, Vec, "rdx.op", ReductionOps);
+      }
+    };
+    for (auto [Vec, Scale, IsSigned] : VectorValuesAndScales) {
+      Value *V = Vec;
+      SmallVector<Value *> Regs = SplitVector(V);
+      for (Value *RegVec : Regs)
+        CreateVecOp(RegVec, Scale);
+    }
+    // Find minimal vector types.
+    SmallVector<std::pair<unsigned, SmallVector<Value *>>> MinVectors;
+    auto AddToMinVectors = [&](unsigned VF, Value *V) {
+      for (auto &P : MinVectors) {
+        if (VF == P.first) {
+          P.second.push_back(V);
+          return;
+        }
+        if (VF < P.first && P.first % VF == 0) {
+          P.first = VF;
+          P.second.push_back(V);
+          return;
+        }
+      }
+      MinVectors.emplace_back(VF, SmallVector<Value *>()).second.push_back(V);
+    };
+    for (auto &P : VecOps) {
+      Value *V = P.second;
+      unsigned VF = getNumElements(P.first);
+      if (isNotFullVectorType(TTI, P.first)) {
+        auto *It =
+            find_if(MinVectors, [&](const auto &P) { return P.first == VF; });
+        if (It == MinVectors.end())
+          MinVectors.emplace_back(VF, SmallVector<Value *>())
+              .second.push_back(V);
+        else
+          It->second.push_back(V);
+        continue;
+      }
+      AddToMinVectors(VF, V);
+    }
+    VecOps.clear();
+    for (auto &P : MinVectors) {
+      const unsigned VF = P.first;
+      for (Value *Vec : P.second) {
+        unsigned VecVF = getNumElements(Vec->getType());
+        if (VecVF == VF) {
+          CreateVecOp(Vec, /*Cnt=*/1);
+          continue;
+        }
+        for (unsigned Part : seq<unsigned>(VecVF / VF)) {
+          Value *Ex = Builder.CreateExtractVector(
+              getWidenedType(Vec->getType()->getScalarType(), VF), Vec,
+              Builder.getInt64(Part * VF));
+          CreateVecOp(Ex, /*Cnt=*/1);
+        }
+      }
+    }
+    for (auto &P : VecOps)
+      CreateSingleOp(P.second, /*Scale=*/1, /*IsSigned=*/false);
+
+    return ReducedSubTree;
+  }
+
   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
                        const TargetTransformInfo *TTI, Type *DestTy) {

>From 708daae5244915e6f104d467d52bbe48d14cae78 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Mon, 2 Dec 2024 13:47:24 +0000
Subject: [PATCH 2/2] Fix formatting

Created using spr 1.3.5
---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 5df21b77643746..38ccf4923bbdd2 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -19238,8 +19238,7 @@ class HorizontalReduction {
     case RecurKind::Or:
       if (UseSelect && OpTy == CmpInst::makeCmpResultType(OpTy))
         return Builder.CreateSelect(
-            LHS,
-            ConstantInt::getAllOnesValue(CmpInst::makeCmpResultType(OpTy)),
+            LHS, ConstantInt::getAllOnesValue(CmpInst::makeCmpResultType(OpTy)),
             RHS, Name);
       return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
                                  Name);
@@ -20150,7 +20149,8 @@ class HorizontalReduction {
                                Lane->getType()->getScalarType() !=
                                        VL.front()->getType()->getScalarType()
                                    ? V.isSignedMinBitwidthRootNode()
-                                   : true, RdxRootInst->getType());
+                                   : true,
+                               RdxRootInst->getType());
             ReducedSubTree =
                 Builder.CreateInsertElement(ReducedSubTree, Val, I);
           }



More information about the llvm-commits mailing list