[llvm] 703d43c - [CostModel] Move default expand cost for partial reductions to BasicTTIImpl (#189905)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 2 01:42:59 PDT 2026


Author: Sander de Smalen
Date: 2026-04-02T09:42:53+01:00
New Revision: 703d43ca3b66279b4ad81b88d6805f6f27edc557

URL: https://github.com/llvm/llvm-project/commit/703d43ca3b66279b4ad81b88d6805f6f27edc557
DIFF: https://github.com/llvm/llvm-project/commit/703d43ca3b66279b4ad81b88d6805f6f27edc557.diff

LOG: [CostModel] Move default expand cost for partial reductions to BasicTTIImpl (#189905)

This is a follow-up to the suggestion left here:

https://github.com/llvm/llvm-project/pull/181707#discussion_r2995733831

The overrides in AMDGPU/ARM/SystemZ/X86 are required to avoid enabling
partial reductions on targets where they were previously disabled (I've
added such an override for each target that implements
getArithmeticReductionCost).

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
    llvm/lib/Target/X86/X86TargetTransformInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index bda6ac45ace4c..aa8f9beacac90 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3441,6 +3441,64 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return RedCost + MulCost + 2 * ExtCost;
   }
 
+  /// Default (expansion-based) cost of a partial reduction: the input
+  /// operand(s) are extended to the accumulator element type, optionally
+  /// combined with \p BinOp, and then reduced into a vector of
+  /// VF / Ratio accumulator elements, costed as Ratio applications of
+  /// \p Opcode, where Ratio is the accumulator element size divided by the
+  /// input element size. Returns an invalid cost if the types and VF do not
+  /// describe a valid partial reduction.
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind,
+      std::optional<FastMathFlags> FMF) const override {
+    unsigned EltSizeAcc = AccumType->getScalarSizeInBits();
+    unsigned EltSizeInA = InputTypeA->getScalarSizeInBits();
+    // Validate the element sizes before dividing by them: a non-zero input
+    // size and an accumulator size that is a multiple of it (and not smaller)
+    // guarantee Ratio >= 1, so neither EltSizeInA nor Ratio is a zero divisor.
+    if (EltSizeInA == 0 || EltSizeAcc < EltSizeInA ||
+        EltSizeAcc % EltSizeInA != 0 || (BinOp && InputTypeA != InputTypeB))
+      return InstructionCost::getInvalid();
+
+    unsigned Ratio = EltSizeAcc / EltSizeInA;
+    if (VF.getKnownMinValue() <= Ratio || VF.getKnownMinValue() % Ratio != 0)
+      return InstructionCost::getInvalid();
+
+    Type *InputVectorType = VectorType::get(InputTypeA, VF);
+    Type *ExtInputVectorType = VectorType::get(AccumType, VF);
+    Type *AccumVectorType =
+        VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
+
+    InstructionCost ExtendCostA = 0;
+    if (OpAExtend != TTI::PartialReductionExtendKind::PR_None)
+      ExtendCostA = getCastInstrCost(
+          TTI::getOpcodeForPartialReductionExtendKind(OpAExtend),
+          ExtInputVectorType, InputVectorType, TTI::CastContextHint::None,
+          CostKind);
+
+    // TODO: add cost of extracting subvectors from the source vector that
+    // is to be partially reduced.
+    InstructionCost ReductionOpCost =
+        Ratio * getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
+
+    if (!BinOp)
+      return ExtendCostA + ReductionOpCost;
+
+    // With a BinOp, InputTypeB == InputTypeA was checked above, so operand
+    // B's extend uses the same source vector type as operand A's.
+    InstructionCost ExtendCostB = 0;
+    if (OpBExtend != TTI::PartialReductionExtendKind::PR_None)
+      ExtendCostB = getCastInstrCost(
+          TTI::getOpcodeForPartialReductionExtendKind(OpBExtend),
+          ExtInputVectorType, InputVectorType, TTI::CastContextHint::None,
+          CostKind);
+    return ExtendCostA + ExtendCostB + ReductionOpCost +
+           getArithmeticInstrCost(*BinOp, ExtInputVectorType, CostKind);
+  }
+
   InstructionCost getVectorSplitCost() const { return 1; }
 
   /// @}

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index de578ea29cbe9..734339e5c7a05 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -6056,34 +6056,14 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
       return Cost * 2;
   }
 
-  // Returns cost of expanding the partial reduction in ISel.
-  auto GetExpandCost = [&]() -> InstructionCost {
-    Type *ExtVectorType =
-        VectorType::get(AccumVectorType->getElementType(), VF);
-    auto ExtendCostA = getCastInstrCost(
-        TTI::getOpcodeForPartialReductionExtendKind(OpAExtend), ExtVectorType,
-        InputVectorType, TTI::CastContextHint::None, CostKind);
-    auto RedOpCost =
-        Ratio * getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
-    if (!BinOp)
-      return ExtendCostA + RedOpCost;
-
-    auto ExtendCostB = getCastInstrCost(
-        TTI::getOpcodeForPartialReductionExtendKind(OpBExtend), ExtVectorType,
-        InputVectorType, TTI::CastContextHint::None, CostKind);
-    return ExtendCostA + ExtendCostB + RedOpCost +
-           getArithmeticInstrCost(*BinOp, ExtVectorType, CostKind);
-  };
-
-  if (IsSub) {
-    // Slightly lower the cost of a sub reduction so that it can be considered
-    // as candidate for 'cdot' operations. This is a somewhat arbitrary number,
-    // because we don't yet model these operations directly.
-    return (8 * GetExpandCost()) / 10;
-  }
-
-  // By default, assume the operation is expanded.
-  return GetExpandCost();
+  // By default, assume the operation is expanded (BasicTTIImplBase's cost).
+  InstructionCost ExpandCost = BaseT::getPartialReductionCost(
+      Opcode, InputTypeA, InputTypeB, AccumType, VF, OpAExtend, OpBExtend,
+      BinOp, CostKind, FMF);
+  // Slightly lower the cost of a sub reduction so that it can be considered
+  // as a candidate for 'cdot' operations. This is a somewhat arbitrary number,
+  // because we don't yet model these operations directly.
+  return (ExpandCost.isValid() && IsSub) ? (8 * ExpandCost) / 10 : ExpandCost;
 }
 
 InstructionCost

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
index eb0a1f202412b..e49881dee57db 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -271,6 +271,15 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
                              std::optional<FastMathFlags> FMF,
                              TTI::TargetCostKind CostKind) const override;
 
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind,
+      std::optional<FastMathFlags> FMF) const override {
+    return InstructionCost::getInvalid(); // Keep partial reductions disabled.
+  }
+
   InstructionCost
   getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                         TTI::TargetCostKind CostKind) const override;

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index f766deb884e0b..0d6d5d202bddf 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -413,6 +413,15 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
   getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                         TTI::TargetCostKind CostKind) const override;
 
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind,
+      std::optional<FastMathFlags> FMF) const override {
+    return InstructionCost::getInvalid(); // Keep partial reductions disabled.
+  }
+
   /// getScalingFactorCost - Return the cost of the scaling used in
   /// addressing mode represented by AM.
   /// If the AM is supported, the return value must be >= 0.

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index d96036067c786..456604ef9f627 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -101,6 +101,16 @@ class SystemZTTIImpl final : public BasicTTIImplBase<SystemZTTIImpl> {
       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
       ArrayRef<const Value *> Args = {},
       const Instruction *CxtI = nullptr) const override;
+
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind,
+      std::optional<FastMathFlags> FMF) const override {
+    return InstructionCost::getInvalid(); // Keep partial reductions disabled.
+  }
+
   InstructionCost
   getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy, VectorType *SrcTy,
                  ArrayRef<int> Mask, TTI::TargetCostKind CostKind, int Index,

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index dbab8edfc6a2c..ee1afdd8aaea9 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -224,6 +224,15 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
                              std::optional<FastMathFlags> FMF,
                              TTI::TargetCostKind CostKind) const override;
 
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind,
+      std::optional<FastMathFlags> FMF) const override {
+    return InstructionCost::getInvalid(); // Keep partial reductions disabled.
+  }
+
   InstructionCost getMinMaxCost(Intrinsic::ID IID, Type *Ty,
                                 TTI::TargetCostKind CostKind,
                                 FastMathFlags FMF) const;


        


More information about the llvm-commits mailing list