[llvm] 6dba5f6 - [TTI] Align optional FMFs in getExtendedReductionCost() to getArithmeticReductionCost(). (#131968)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 19 03:53:41 PDT 2025


Author: Elvis Wang
Date: 2025-03-19T18:53:38+08:00
New Revision: 6dba5f659567af21a01224845077a269380d8e3a

URL: https://github.com/llvm/llvm-project/commit/6dba5f659567af21a01224845077a269380d8e3a
DIFF: https://github.com/llvm/llvm-project/commit/6dba5f659567af21a01224845077a269380d8e3a.diff

LOG: [TTI] Align optional FMFs in getExtendedReductionCost() to getArithmeticReductionCost(). (#131968)

The implementation of getExtendedReductionCost() often calls
getArithmeticReductionCost() with FMFs. But we shouldn't call
getArithmeticReductionCost() with FMFs for non-floating-point reductions,
because doing so will return the wrong cost.

This patch makes the FMFs in getExtendedReductionCost() optional, aligning
it with getArithmeticReductionCost(), so that TTI returns the correct
cost for non-FP extended-reduction queries made without FMFs.

This patch is not quite NFC but it's hard to test from the CostModel
side.

Split from #113903.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 4d311e7e9fd6a..7ae4913d7c037 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1622,7 +1622,7 @@ class TargetTransformInfo {
   /// ResTy vecreduce.opcode(ext(Ty A)).
   InstructionCost getExtendedReductionCost(
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
-      FastMathFlags FMF,
+      std::optional<FastMathFlags> FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// \returns The cost of Intrinsic instructions. Analyses the real arguments.
@@ -2267,7 +2267,7 @@ class TargetTransformInfo::Concept {
                          TTI::TargetCostKind CostKind) = 0;
   virtual InstructionCost getExtendedReductionCost(
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
-      FastMathFlags FMF,
+      std::optional<FastMathFlags> FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
   virtual InstructionCost getMulAccReductionCost(
       bool IsUnsigned, Type *ResTy, VectorType *Ty,
@@ -3023,7 +3023,7 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
   InstructionCost
   getExtendedReductionCost(unsigned Opcode, bool IsUnsigned, Type *ResTy,
-                           VectorType *Ty, FastMathFlags FMF,
+                           VectorType *Ty, std::optional<FastMathFlags> FMF,
                            TTI::TargetCostKind CostKind) override {
     return Impl.getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
                                          CostKind);

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index c15694916a732..31ab919080744 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -885,7 +885,7 @@ class TargetTransformInfoImplBase {
 
   InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
                                            Type *ResTy, VectorType *Ty,
-                                           FastMathFlags FMF,
+                                           std::optional<FastMathFlags> FMF,
                                            TTI::TargetCostKind CostKind) const {
     return 1;
   }

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index d46859bcb0517..eacf75c24695f 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3041,7 +3041,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
                                            Type *ResTy, VectorType *Ty,
-                                           FastMathFlags FMF,
+                                           std::optional<FastMathFlags> FMF,
                                            TTI::TargetCostKind CostKind) {
     if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
         FTy && IsUnsigned && Opcode == Instruction::Add &&
@@ -3050,7 +3050,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
       auto *IntTy =
           IntegerType::get(ResTy->getContext(), FTy->getNumElements());
-      IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+      IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy},
+                                  FMF ? *FMF : FastMathFlags());
       return thisT()->getCastInstrCost(Instruction::BitCast, IntTy, FTy,
                                        TTI::CastContextHint::None, CostKind) +
              thisT()->getIntrinsicInstrCost(ICA, CostKind);

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 36f2983390a48..bd1312d8c2d0b 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1251,7 +1251,7 @@ InstructionCost TargetTransformInfo::getMinMaxReductionCost(
 
 InstructionCost TargetTransformInfo::getExtendedReductionCost(
     unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
-    FastMathFlags FMF, TTI::TargetCostKind CostKind) const {
+    std::optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) const {
   return TTIImpl->getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
                                            CostKind);
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7cec8a17dfaaa..a3bf8c53571f7 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4774,7 +4774,7 @@ AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
 
 InstructionCost AArch64TTIImpl::getExtendedReductionCost(
     unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *VecTy,
-    FastMathFlags FMF, TTI::TargetCostKind CostKind) {
+    std::optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) {
   EVT VecVT = TLI->getValueType(DL, VecTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
 

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 8a3fd11705640..4d69c1e5bc732 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -426,7 +426,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
                                            Type *ResTy, VectorType *ValTy,
-                                           FastMathFlags FMF,
+                                           std::optional<FastMathFlags> FMF,
                                            TTI::TargetCostKind CostKind);
 
   InstructionCost getMulAccReductionCost(

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 785067d327271..cc8a6d9449a05 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1784,7 +1784,7 @@ ARMTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
 
 InstructionCost ARMTTIImpl::getExtendedReductionCost(
     unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *ValTy,
-    FastMathFlags FMF, TTI::TargetCostKind CostKind) {
+    std::optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) {
   EVT ValVT = TLI->getValueType(DL, ValTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
 

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 3e54de5aa9bab..103d2ed1c6281 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -284,7 +284,7 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
                                              TTI::TargetCostKind CostKind);
   InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
                                            Type *ResTy, VectorType *ValTy,
-                                           FastMathFlags FMF,
+                                           std::optional<FastMathFlags> FMF,
                                            TTI::TargetCostKind CostKind);
   InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                          VectorType *ValTy,

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 1060093043278..f49ad2f5bd20e 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1889,7 +1889,7 @@ RISCVTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
 
 InstructionCost RISCVTTIImpl::getExtendedReductionCost(
     unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *ValTy,
-    FastMathFlags FMF, TTI::TargetCostKind CostKind) {
+    std::optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) {
   if (isa<FixedVectorType>(ValTy) && !ST->useRVVForFixedLengthVectors())
     return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
                                            FMF, CostKind);

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index ac46db5faf28d..8ffe1b08d1e26 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -211,7 +211,7 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
 
   InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
                                            Type *ResTy, VectorType *ValTy,
-                                           FastMathFlags FMF,
+                                           std::optional<FastMathFlags> FMF,
                                            TTI::TargetCostKind CostKind);
 
   InstructionCost


        


More information about the llvm-commits mailing list