[llvm] [SLP] Use getExtendedReductionCost and fix reduction cost calculations (PR #117350)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 22 09:33:52 PST 2024


https://github.com/alexey-bataev created https://github.com/llvm/llvm-project/pull/117350

This patch uses getExtendedReductionCost for reductions of ext-based nodes,
adds cost estimation for ctpop-kind reductions to the basic TTI
implementation, and adds a RISC-V specific vcpop cost estimation.


>From 6e8a8f667dd4e0a2d94f7c13c3ffa661a21ece7e Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Fri, 22 Nov 2024 17:33:41 +0000
Subject: [PATCH] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20initia?=
 =?UTF-8?q?l=20version?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
---
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |  12 ++
 .../Target/RISCV/RISCVTargetTransformInfo.cpp |   8 +
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 151 ++++++++++++------
 .../SLPVectorizer/RISCV/reductions.ll         |  18 +--
 .../remark-zext-incoming-for-neg-icmp.ll      |   2 +-
 5 files changed, 125 insertions(+), 66 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b3583e2819ee4c..d4bd504bf9ba31 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2765,6 +2765,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                            Type *ResTy, VectorType *Ty,
                                            FastMathFlags FMF,
                                            TTI::TargetCostKind CostKind) {
+    if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
+        FTy && Opcode == Instruction::Add &&
+        FTy->getElementType() == IntegerType::getInt1Ty(Ty->getContext())) {
+      // Represent vector_reduce_add(ZExt(<n x i1>)) as
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+      auto *IntTy =
+          IntegerType::get(ResTy->getContext(), FTy->getNumElements());
+      IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+      return thisT()->getCastInstrCost(Instruction::BitCast, IntTy, FTy,
+                                       TTI::CastContextHint::None, CostKind) +
+             thisT()->getIntrinsicInstrCost(ICA, CostKind);
+    }
     // Without any native support, this is equivalent to the cost of
     // vecreduce.opcode(ext(Ty A)).
     VectorType *ExtTy = VectorType::get(ResTy, Ty);
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 2b16dcbcd8695b..026b1e694e6a64 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1620,6 +1620,14 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(
 
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
 
+  if (Opcode == Instruction::Add && LT.second.isFixedLengthVector() &&
+      LT.second.getScalarType() == MVT::i1) {
+    // Represent vector_reduce_add(ZExt(<n x i1>)) as
+    // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+    return LT.first *
+           getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind);
+  }
+
   if (ResTy->getScalarSizeInBits() != 2 * LT.second.getScalarSizeInBits())
     return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
                                            FMF, CostKind);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8e0ca2677bf0a9..46ae908f57ab89 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1371,22 +1371,46 @@ class BoUpSLP {
     return VectorizableTree.front()->Scalars;
   }
 
+  /// Returns the type/is-signed info for the root node in the graph without
+  /// casting.
+  std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
+    const TreeEntry &Root = *VectorizableTree.front().get();
+    if (Root.State != TreeEntry::Vectorize || Root.isAltShuffle() ||
+        !Root.Scalars.front()->getType()->isIntegerTy())
+      return std::nullopt;
+    auto It = MinBWs.find(&Root);
+    if (It != MinBWs.end())
+      return std::make_pair(IntegerType::get(Root.Scalars.front()->getContext(),
+                                             It->second.first),
+                            It->second.second);
+    if (Root.getOpcode() == Instruction::ZExt ||
+        Root.getOpcode() == Instruction::SExt)
+      return std::make_pair(cast<CastInst>(Root.getMainOp())->getSrcTy(),
+                            Root.getOpcode() == Instruction::SExt);
+    return std::nullopt;
+  }
+
   /// Checks if the root graph node can be emitted with narrower bitwidth at
   /// codegen and returns it signedness, if so.
   bool isSignedMinBitwidthRootNode() const {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
-  /// Returns reduction bitwidth and signedness, if it does not match the
-  /// original requested size.
-  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+  /// Returns the reduction type after minimum-bitwidth analysis.
+  FixedVectorType *getReductionType() const {
     if (ReductionBitWidth == 0 ||
+        !VectorizableTree.front()->Scalars.front()->getType()->isIntegerTy() ||
         ReductionBitWidth >=
             DL->getTypeSizeInBits(
                 VectorizableTree.front()->Scalars.front()->getType()))
-      return std::nullopt;
-    return std::make_pair(ReductionBitWidth,
-                          MinBWs.at(VectorizableTree.front().get()).second);
+      return getWidenedType(
+          VectorizableTree.front()->Scalars.front()->getType(),
+          VectorizableTree.front()->getVectorFactor());
+    return getWidenedType(
+        IntegerType::get(
+            VectorizableTree.front()->Scalars.front()->getContext(),
+            ReductionBitWidth),
+        VectorizableTree.front()->getVectorFactor());
   }
 
   /// Builds external uses of the vectorized scalars, i.e. the list of
@@ -11297,6 +11321,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         return CommonCost;
       auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
       TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
+
+      bool IsArithmeticExtendedReduction =
+          E->Idx == 0 && UserIgnoreList &&
+          all_of(*UserIgnoreList, [](Value *V) {
+            auto *I = cast<Instruction>(V);
+            return is_contained({Instruction::Add, Instruction::FAdd,
+                                 Instruction::Mul, Instruction::FMul,
+                                 Instruction::And, Instruction::Or,
+                                 Instruction::Xor},
+                                I->getOpcode());
+          });
+      if (IsArithmeticExtendedReduction &&
+          (VecOpcode == Instruction::ZExt || VecOpcode == Instruction::SExt))
+        return CommonCost;
       return CommonCost +
              TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
                                    VecOpcode == Opcode ? VI : nullptr);
@@ -12652,32 +12690,48 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       unsigned SrcSize = It->second.first;
       unsigned DstSize = ReductionBitWidth;
       unsigned Opcode = Instruction::Trunc;
-      if (SrcSize < DstSize)
-        Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
-      auto *SrcVecTy =
-          getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
-      auto *DstVecTy =
-          getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
-      TTI::CastContextHint CCH = getCastContextHint(E);
-      InstructionCost CastCost;
-      switch (E.getOpcode()) {
-      case Instruction::SExt:
-      case Instruction::ZExt:
-      case Instruction::Trunc: {
-        const TreeEntry *OpTE = getOperandEntry(&E, 0);
-        CCH = getCastContextHint(*OpTE);
-        break;
-      }
-      default:
-        break;
+      if (SrcSize < DstSize) {
+        bool IsArithmeticExtendedReduction =
+            all_of(*UserIgnoreList, [](Value *V) {
+              auto *I = cast<Instruction>(V);
+              return is_contained({Instruction::Add, Instruction::FAdd,
+                                   Instruction::Mul, Instruction::FMul,
+                                   Instruction::And, Instruction::Or,
+                                   Instruction::Xor},
+                                  I->getOpcode());
+            });
+        if (IsArithmeticExtendedReduction)
+          Opcode =
+              Instruction::BitCast; // Handle it by getExtendedReductionCost
+        else
+          Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+      }
+      if (Opcode != Instruction::BitCast) {
+        auto *SrcVecTy =
+            getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
+        auto *DstVecTy =
+            getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
+        TTI::CastContextHint CCH = getCastContextHint(E);
+        InstructionCost CastCost;
+        switch (E.getOpcode()) {
+        case Instruction::SExt:
+        case Instruction::ZExt:
+        case Instruction::Trunc: {
+          const TreeEntry *OpTE = getOperandEntry(&E, 0);
+          CCH = getCastContextHint(*OpTE);
+          break;
+        }
+        default:
+          break;
+        }
+        CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
+                                          TTI::TCK_RecipThroughput);
+        Cost += CastCost;
+        LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
+                          << " for final resize for reduction from " << SrcVecTy
+                          << " to " << DstVecTy << "\n";
+                   dbgs() << "SLP: Current total cost = " << Cost << "\n");
       }
-      CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
-                                        TTI::TCK_RecipThroughput);
-      Cost += CastCost;
-      LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
-                        << " for final resize for reduction from " << SrcVecTy
-                        << " to " << DstVecTy << "\n";
-                 dbgs() << "SLP: Current total cost = " << Cost << "\n");
     }
   }
 
@@ -19815,8 +19869,8 @@ class HorizontalReduction {
 
         // Estimate cost.
         InstructionCost TreeCost = V.getTreeCost(VL);
-        InstructionCost ReductionCost = getReductionCost(
-            TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
+        InstructionCost ReductionCost =
+            getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
         InstructionCost Cost = TreeCost + ReductionCost;
         LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                           << " for reduction\n");
@@ -20107,14 +20161,14 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(
-      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
-      bool IsCmpSelMinMax, FastMathFlags FMF,
-      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
+  InstructionCost getReductionCost(TargetTransformInfo *TTI,
+                                   ArrayRef<Value *> ReducedVals,
+                                   bool IsCmpSelMinMax, FastMathFlags FMF,
+                                   const BoUpSLP &R) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
     unsigned ReduxWidth = ReducedVals.size();
-    FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
+    FixedVectorType *VectorTy = R.getReductionType();
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
     // the reduction value can be calculated at the compile time.
@@ -20172,21 +20226,16 @@ class HorizontalReduction {
               VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
               /*Extract*/ false, TTI::TCK_RecipThroughput);
         } else {
-          auto [Bitwidth, IsSigned] =
-              BitwidthAndSign.value_or(std::make_pair(0u, false));
-          if (RdxKind == RecurKind::Add && Bitwidth == 1) {
-            // Represent vector_reduce_add(ZExt(<n x i1>)) to
-            // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
-            auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
-            IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
-            VectorCost =
-                TTI->getCastInstrCost(Instruction::BitCast, IntTy,
-                                      getWidenedType(ScalarTy, ReduxWidth),
-                                      TTI::CastContextHint::None, CostKind) +
-                TTI->getIntrinsicInstrCost(ICA, CostKind);
-          } else {
+          Type *RedTy = VectorTy->getElementType();
+          auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
+              std::make_pair(RedTy, true));
+          if (RType == RedTy) {
             VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
                                                          FMF, CostKind);
+          } else {
+            VectorCost = TTI->getExtendedReductionCost(
+                RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
+                FMF, CostKind);
           }
         }
       }
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
index bc24a44cecbe39..85131758853b3d 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
@@ -877,20 +877,10 @@ entry:
 define i64 @red_zext_ld_4xi64(ptr %ptr) {
 ; CHECK-LABEL: @red_zext_ld_4xi64(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[LD0]] to i64
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1
-; CHECK-NEXT:    [[LD1:%.*]] = load i8, ptr [[GEP]], align 1
-; CHECK-NEXT:    [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64
-; CHECK-NEXT:    [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]]
-; CHECK-NEXT:    [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2
-; CHECK-NEXT:    [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1
-; CHECK-NEXT:    [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64
-; CHECK-NEXT:    [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]]
-; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3
-; CHECK-NEXT:    [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1
-; CHECK-NEXT:    [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64
-; CHECK-NEXT:    [[ADD_3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16>
+; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]])
+; CHECK-NEXT:    [[ADD_3:%.*]] = zext i16 [[TMP2]] to i64
 ; CHECK-NEXT:    ret i64 [[ADD_3]]
 ;
 entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
index e4d20a6db8fa67..09c11bbefd4a35 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
@@ -8,7 +8,7 @@
 ; YAML-NEXT: Function:        test
 ; YAML-NEXT: Args:
 ; YAML-NEXT:   - String:          'Vectorized horizontal reduction with cost '
-; YAML-NEXT:   - Cost:            '-1'
+; YAML-NEXT:   - Cost:            '-10'
 ; YAML-NEXT:   - String:          ' and with tree size '
 ; YAML-NEXT:   - TreeSize:        '8'
 ; YAML-NEXT:...



More information about the llvm-commits mailing list