[llvm] [SLP]Model reduction_add(ext(<n x i1>)) as ext(ctpop(bitcast <n x i1> to int n)) (PR #116875)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 19 12:54:57 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

<details>
<summary>Changes</summary>

Currently, reduction_add(ext(<n x i1>)) sequences are modeled as vector
extensions followed by an add reduction, but InstCombine later transforms
them into ext(ctpop(bitcast <n x i1> to iN)). This patch adds direct
support for this pattern to the SLP vectorizer, which enables more
accurate cost estimation.

AVX512, -O3+LTO

CINT2006/445.gobmk - extra vector code
Prolangs-C/bison - extra vector code
Benchmarks/NPB-serial/is - 16 x and 8 x reductions are now vectorized as
a single 24 x reduction


---
Full diff: https://github.com/llvm/llvm-project/pull/116875.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+80-27) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll (+3-2) 
- (modified) llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll (+3-2) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e70627b6afc10d..fe5099d68024c3 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1371,6 +1371,18 @@ class BoUpSLP {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
+  /// Returns the reduction bitwidth and signedness if a reduction bitwidth was
+  /// computed and differs from the original scalar type size, else nullopt.
+  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+    if (ReductionBitWidth == 0 ||
+        ReductionBitWidth ==
+            DL->getTypeSizeInBits(
+                VectorizableTree.front()->Scalars.front()->getType()))
+      return std::nullopt;
+    return std::make_pair(ReductionBitWidth,
+                          MinBWs.at(VectorizableTree.front().get()).second);
+  }
+
   /// Builds external uses of the vectorized scalars, i.e. the list of
   /// vectorized scalars to be extracted, their lanes and their scalar users. \p
   /// ExternallyUsedValues contains additional list of external uses to handle
@@ -17885,24 +17897,37 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Add reduction ops sizes, if any.
   if (UserIgnoreList &&
       isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
-    for (Value *V : *UserIgnoreList) {
-      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
-      auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
-      unsigned BitWidth1 = NumTypeBits - NumSignBits;
-      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
-        ++BitWidth1;
-      unsigned BitWidth2 = BitWidth1;
-      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
-        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+    // ReductionBitWidth == 1 marks vector_reduce_add(ZExt(<n x i1>)), which is
+    // emitted as ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+    if (all_of(*UserIgnoreList,
+               [](Value *V) {
+                 return cast<Instruction>(V)->getOpcode() == Instruction::Add;
+               }) &&
+        VectorizableTree.front()->State == TreeEntry::Vectorize &&
+        VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
+        cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
+            Builder.getInt1Ty()) {
+      ReductionBitWidth = 1;
+    } else {
+      for (Value *V : *UserIgnoreList) {
+        auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+        auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
+        unsigned BitWidth1 = NumTypeBits - NumSignBits;
+        if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+          ++BitWidth1;
+        unsigned BitWidth2 = BitWidth1;
+        if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+          auto Mask = DB->getDemandedBits(cast<Instruction>(V));
+          BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+        }
+        ReductionBitWidth =
+            std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
       }
-      ReductionBitWidth =
-          std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
-    }
-    if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
-      ReductionBitWidth = 8;
+      if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
+        ReductionBitWidth = 8;
 
-    ReductionBitWidth = bit_ceil(ReductionBitWidth);
+      ReductionBitWidth = bit_ceil(ReductionBitWidth);
+    }
   }
   bool IsTopRoot = NodeIdx == 0;
   while (NodeIdx < VectorizableTree.size() &&
@@ -19758,8 +19783,8 @@ class HorizontalReduction {
 
         // Estimate cost.
         InstructionCost TreeCost = V.getTreeCost(VL);
-        InstructionCost ReductionCost =
-            getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
+        InstructionCost ReductionCost = getReductionCost(
+            TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
         InstructionCost Cost = TreeCost + ReductionCost;
         LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                           << " for reduction\n");
@@ -19864,10 +19889,12 @@ class HorizontalReduction {
                 createStrideMask(I, ScalarTyNumElements, VL.size());
             Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
             ReducedSubTree = Builder.CreateInsertElement(
-                ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
+                ReducedSubTree,
+                emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
           }
         } else {
-          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
+          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
+                                         RdxRootInst->getType());
         }
         if (ReducedSubTree->getType() != VL.front()->getType()) {
           assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20048,12 +20075,13 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(TargetTransformInfo *TTI,
-                                   ArrayRef<Value *> ReducedVals,
-                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
-                                   FastMathFlags FMF) {
+  InstructionCost getReductionCost(
+      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
+      bool IsCmpSelMinMax, FastMathFlags FMF,
+      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
+    unsigned ReduxWidth = ReducedVals.size();
     FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
@@ -20112,8 +20140,22 @@ class HorizontalReduction {
               VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
               /*Extract*/ false, TTI::TCK_RecipThroughput);
         } else {
-          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
-                                                       CostKind);
+          auto [Bitwidth, IsSigned] =
+              BitwidthAndSign.value_or(std::make_pair(0u, false));
+          if (RdxKind == RecurKind::Add && Bitwidth == 1) {
+            // Represent vector_reduce_add(ZExt(<n x i1>)) as
+            // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+            auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
+            IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+            VectorCost =
+                TTI->getCastInstrCost(Instruction::BitCast, IntTy,
+                                      getWidenedType(ScalarTy, ReduxWidth),
+                                      TTI::CastContextHint::None, CostKind) +
+                TTI->getIntrinsicInstrCost(ICA, CostKind);
+          } else {
+            VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                         FMF, CostKind);
+          }
         }
       }
       ScalarCost = EvaluateScalarCost([&]() {
@@ -20150,11 +20192,22 @@ class HorizontalReduction {
 
   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
-                       const TargetTransformInfo *TTI) {
+                       const TargetTransformInfo *TTI, Type *DestTy) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(RdxKind != RecurKind::FMulAdd &&
            "A call to the llvm.fmuladd intrinsic is not handled yet");
 
+    auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
+    if (FTy->getScalarType() == Builder.getInt1Ty() &&
+        RdxKind == RecurKind::Add &&
+        DestTy->getScalarType() != FTy->getScalarType()) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+      Value *V = Builder.CreateBitCast(
+          VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
+      ++NumVectorInstructions;
+      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+    }
     ++NumVectorInstructions;
     return createSimpleReduction(Builder, VectorizedValue, RdxKind);
   }
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll b/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll
index ecf85159efdfbd..f00b846bf4f5bd 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll
@@ -11,8 +11,9 @@ define i16 @test(i16 %call37) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp slt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> [[TMP3]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <8 x i1> [[TMP4]] to <8 x i16>
-; CHECK-NEXT:    [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <8 x i1> [[TMP4]] to i8
+; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP8]])
+; CHECK-NEXT:    [[TMP6:%.*]] = zext i8 [[TMP7]] to i16
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i16 [[TMP6]], 0
 ; CHECK-NEXT:    ret i16 [[OP_RDX]]
 ;
diff --git a/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll
index 89fcc7e983749b..303e31dfa5e64a 100644
--- a/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll
@@ -14,8 +14,9 @@ define i32 @test(i32 %a, i8 %b, i8 %c) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP9:%.*]] = sext <4 x i8> [[TMP4]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp sle <4 x i16> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <4 x i1> [[TMP5]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <4 x i1> [[TMP5]] to i4
+; CHECK-NEXT:    [[TMP11:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP10]])
+; CHECK-NEXT:    [[TMP7:%.*]] = zext i4 [[TMP11]] to i32
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i32 [[TMP7]], [[A]]
 ; CHECK-NEXT:    ret i32 [[OP_RDX]]
 ;

``````````

</details>


https://github.com/llvm/llvm-project/pull/116875


More information about the llvm-commits mailing list