[llvm] [TTI] Plumb CostKind through getPartialReductionCost (PR #144953)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 19 13:49:24 PDT 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff HEAD~1 HEAD --extensions h,cpp -- llvm/include/llvm/Analysis/TargetTransformInfo.h llvm/include/llvm/Analysis/TargetTransformInfoImpl.h llvm/lib/Analysis/TargetTransformInfo.cpp llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h llvm/lib/Transforms/Vectorize/LoopVectorize.cpp llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 8b7530299..ba47cef27 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1332,8 +1332,8 @@ public:
   LLVM_ABI InstructionCost getPartialReductionCost(
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, PartialReductionExtendKind OpAExtend,
-      PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const;
+      PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const;
 
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 8e3b91741..640766cf8 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -652,13 +652,11 @@ public:
   virtual bool enableWritePrefetching() const { return false; }
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
-  virtual InstructionCost
-  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
-                          Type *AccumType, ElementCount VF,
-                          TTI::PartialReductionExtendKind OpAExtend,
-                          TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp,
-                          TTI::TargetCostKind CostKind) const {
+  virtual InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const {
     return InstructionCost::getInvalid();
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 39d6bc7a8..d720e9a64 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5395,8 +5395,8 @@ AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index,
 InstructionCost AArch64TTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-    TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
+    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   InstructionCost Cost(TTI::TCC_Basic);
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 7518e9e66..470af01be 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -382,13 +382,11 @@ public:
     return BaseT::isLegalNTLoad(DataType, Alignment);
   }
 
-  InstructionCost
-  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
-                          Type *AccumType, ElementCount VF,
-                          TTI::PartialReductionExtendKind OpAExtend,
-                          TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp,
-                          TTI::TargetCostKind CostKind) const override;
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
 
   bool enableOrderedReductions() const override { return true; }
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 8d1ceb0f5..95829866e 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
 InstructionCost RISCVTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-    TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
+    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
 
   // zve32x is broken for partial_reduce_umla, but let's make sure we
   // don't generate them.
@@ -312,8 +312,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
   // Note: Asuming all vqdot* variants are equal cost
   // TODO: Thread CostKind through this API
-  return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
-                                            CostKind);
+  return LT.first *
+         getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, CostKind);
 }
 
 bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 1fe73846b..83ac71ed9 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -100,13 +100,11 @@ public:
   TargetTransformInfo::PopcntSupportKind
   getPopcntSupport(unsigned TyWidth) const override;
 
-  InstructionCost
-  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
-                          Type *AccumType, ElementCount VF,
-                          TTI::PartialReductionExtendKind OpAExtend,
-                          TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp,
-                          TTI::TargetCostKind CostKind) const override;
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
 
   bool shouldExpandReduction(const IntrinsicInst *II) const override;
   bool supportsScalableVectors() const override {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index b9cc5c328..359084f31 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -198,8 +198,8 @@ InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
 InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-    TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
+    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   if (!VF.isFixed() || !ST->hasSIMD128())
     return Invalid;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 4e183921a..d83b8d1f4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -86,8 +86,8 @@ public:
   InstructionCost getPartialReductionCost(
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-      TTI::PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const override;
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
   TTI::ReductionShuffle
   getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 029a2b101..22861eb1c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -336,9 +336,9 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
     return TargetTransformInfo::PR_None;
   };
 
-  return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
-                                         PhiType, VF, GetExtendKind(ExtAR),
-                                         GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
+  return Ctx.TTI.getPartialReductionCost(
+      getOpcode(), InputTypeA, InputTypeB, PhiType, VF, GetExtendKind(ExtAR),
+      GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
 }
 
 void VPPartialReductionRecipe::execute(VPTransformState &State) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/144953


More information about the llvm-commits mailing list