[llvm] [TTI] Introduce utilities for target costing of build & explode vector [NFC] (PR #85455)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 15 12:39:06 PDT 2024


https://github.com/preames created https://github.com/llvm/llvm-project/pull/85455

Introduce utilities for costing build vector and explode vector operations inside the TTI target implementation logic.  As the converted call sites in this patch show, these are by far the most common operations actually performed.

In case the goal isn't clear here, I plan to eliminate getScalarizationOverhead from the TTI interface layer.  All of our targets cost a combined insert and extract as equivalent to an explode vector followed by a build vector, so the combined interface can be killed off.
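
To illustrate the equivalence claimed above, here is a minimal standalone sketch. It does not use LLVM: `ToyCostModel`, `EltMask`, and the per-element costs are hypothetical stand-ins, and the additive cost formula is an assumption made for illustration only. It shows the build and explode helpers as thin covers over a flag-based overhead query, and checks that the combined insert-and-extract query equals explode plus build under that assumption.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>

// Hypothetical toy cost model; none of these names are LLVM APIs.
constexpr std::size_t MaxElts = 16;
using EltMask = std::bitset<MaxElts>;

struct ToyCostModel {
  unsigned InsertCost = 1;  // assumed cost of inserting one element
  unsigned ExtractCost = 2; // assumed cost of extracting one element

  // Combined interface, shaped like a flag-based overhead query.
  unsigned scalarizationOverhead(const EltMask &Demanded, bool Insert,
                                 bool Extract) const {
    unsigned Cost = 0;
    for (std::size_t I = 0; I < MaxElts; ++I) {
      if (!Demanded[I])
        continue; // only demanded lanes contribute
      if (Insert)
        Cost += InsertCost;
      if (Extract)
        Cost += ExtractCost;
    }
    return Cost;
  }

  // Cost of assembling a vector from scalars (insert only).
  unsigned buildVectorCost(const EltMask &Demanded) const {
    return scalarizationOverhead(Demanded, /*Insert=*/true, /*Extract=*/false);
  }

  // Cost of splitting a vector into scalars (extract only).
  unsigned explodeVectorCost(const EltMask &Demanded) const {
    return scalarizationOverhead(Demanded, /*Insert=*/false, /*Extract=*/true);
  }
};

// Self-check: the combined query adds nothing beyond the two helpers.
inline void checkToyModel() {
  ToyCostModel M;
  EltMask Demanded("1011"); // lanes 0, 1 and 3 are demanded
  assert(M.buildVectorCost(Demanded) == 3);
  assert(M.explodeVectorCost(Demanded) == 6);
  assert(M.scalarizationOverhead(Demanded, true, true) ==
         M.explodeVectorCost(Demanded) + M.buildVectorCost(Demanded));
}
```

Calling `checkToyModel()` from a `main` or a test harness exercises the three cases. A real target may of course use a non-additive cost; the sketch only captures the shape of the argument.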

This is the inverse of https://github.com/llvm/llvm-project/pull/85421. Once both patches land, only the actual meat of the change remains.

One subtlety here: we have to be very careful to call the directly analogous cover function.  There is a base class and a subclass involved, and at times it matters whether a method is called on the subclass or on the base class.  This is harder to follow because there are multiple getScalarizationOverhead variants with different signatures; most exist only on the base class, but some (not all) proxy back to the subclass.
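
The base-class versus subclass distinction can be sketched with a small CRTP example. This is not the LLVM class hierarchy: `CostBase`, `TargetCost`, and the method names are hypothetical, chosen only to show how a call that proxies through the subclass can return a different cost from one that stays on the base class.

```cpp
#include <cassert>

// Hypothetical names throughout; this is not the LLVM class hierarchy.
template <typename T> class CostBase {
  T *thisT() { return static_cast<T *>(this); }

public:
  // Default implementation; a subclass may shadow it.
  unsigned overhead(unsigned NumElts) { return NumElts; }

  // Proxies back to the subclass through thisT().
  unsigned overheadViaSubclass(unsigned NumElts) {
    return thisT()->overhead(NumElts);
  }

  // Unqualified call inside the base: always binds to CostBase::overhead.
  unsigned overheadViaBase(unsigned NumElts) { return overhead(NumElts); }
};

class TargetCost : public CostBase<TargetCost> {
  using BaseT = CostBase<TargetCost>;

public:
  // Shadows the base method and adds a made-up target surcharge.
  unsigned overhead(unsigned NumElts) { return BaseT::overhead(NumElts) + 10; }
};

// Self-check: the two cover functions disagree once the subclass shadows.
inline void checkDispatch() {
  TargetCost T;
  assert(T.overhead(4) == 14);            // subclass version
  assert(T.overheadViaSubclass(4) == 14); // proxied to subclass
  assert(T.overheadViaBase(4) == 4);      // stays on base class
}
```

In this sketch, swapping `overheadViaBase` for `overheadViaSubclass` at a call site would silently change the returned cost, which is the kind of mismatch an NFC refactoring has to avoid.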

From 4628b586b398b863bb78f5acde48fcb719f5d41b Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 15 Mar 2024 09:57:29 -0700
Subject: [PATCH] [TTI] Introduce utilities for target costing build & explode
 vector

Introduce utilities for costing build vector and explode vector
operations inside the TTI target implementation logic.  As can be seen
these are by far the most common operations actually performed.

In case the goal isn't clear here, I plan to eliminate
getScalarizationOverhead from the TTI interface layer.  All of our
targets cost a combined insert and extract as equivalent to an
explode vector followed by a build vector, so the combined interface
can be killed off.

This is the inverse of https://github.com/llvm/llvm-project/pull/85421. Once both patches land, only the actual meat of the change remains.

One subtlety here - we have to be very careful to make sure we're
calling the directly analogous cover function.  We've got a base
class and subclass involved here, and it's important at times
whether we call a method on the subclass or the base class.  This is
harder to follow since we have multiple getScalarizationOverhead
variants with different signatures - most of which only exist on the
base class, but some (not all) of which proxy back to the sub-class.
---
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      | 61 +++++++++++--------
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp | 15 ++---
 .../SystemZ/SystemZTargetTransformInfo.cpp    | 14 ++---
 .../lib/Target/X86/X86TargetTransformInfo.cpp | 12 ++--
 4 files changed, 52 insertions(+), 50 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 61f6564e8cd79b..18e0896650f0dd 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -782,6 +782,20 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return Cost;
   }
 
+  InstructionCost getBuildVectorCost(VectorType *InTy,
+                                     const APInt &DemandedElts,
+                                     TTI::TargetCostKind CostKind) {
+    return getScalarizationOverhead(InTy, DemandedElts, /*Insert=*/ true,
+                                    /*Extract=*/ false, CostKind);
+  }
+
+  InstructionCost getExplodeVectorCost(VectorType *InTy,
+                                       const APInt &DemandedElts,
+                                       TTI::TargetCostKind CostKind) {
+    return getScalarizationOverhead(InTy, DemandedElts, /*Insert=*/ false,
+                                    /*Extract=*/ true, CostKind);
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,
@@ -795,6 +809,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                              CostKind);
   }
 
+  InstructionCost getBuildVectorCost(VectorType *InTy,
+                                     TTI::TargetCostKind CostKind) {
+    return getScalarizationOverhead(InTy, /*Insert=*/ true, /*Extract=*/ false,
+                                    CostKind);
+  }
+
+  InstructionCost getExplodeVectorCost(VectorType *InTy,
+                                       TTI::TargetCostKind CostKind) {
+    return getScalarizationOverhead(InTy, /*Insert=*/ false, /*Extract=*/ true,
+                                    CostKind);
+  }
+
   /// Estimate the overhead of scalarizing an instructions unique
   /// non-constant operands. The (potentially vector) types to use for each of
   /// argument are passes via Tys.
@@ -816,8 +842,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
       if (!isa<Constant>(A) && UniqueOperands.insert(A).second) {
         if (auto *VecTy = dyn_cast<VectorType>(Ty))
-          Cost += getScalarizationOverhead(VecTy, /*Insert*/ false,
-                                           /*Extract*/ true, CostKind);
+          Cost += getExplodeVectorCost(VecTy, CostKind);
       }
     }
 
@@ -1186,12 +1211,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     //  that the conversion is scalarized in one way or another.
     if (Opcode == Instruction::BitCast) {
       // Illegal bitcasts are done by storing and loading from a stack slot.
-      return (SrcVTy ? getScalarizationOverhead(SrcVTy, /*Insert*/ false,
-                                                /*Extract*/ true, CostKind)
-                     : 0) +
-             (DstVTy ? getScalarizationOverhead(DstVTy, /*Insert*/ true,
-                                                /*Extract*/ false, CostKind)
-                     : 0);
+      return (SrcVTy ? getExplodeVectorCost(SrcVTy, CostKind) : 0) +
+             (DstVTy ? getBuildVectorCost(DstVTy, CostKind) : 0);
     }
 
     llvm_unreachable("Unhandled cast");
@@ -1254,9 +1275,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
       // Return the cost of multiple scalar invocation plus the cost of
       // inserting and extracting the values.
-      return getScalarizationOverhead(ValVTy, /*Insert*/ true,
-                                      /*Extract*/ false, CostKind) +
-             Num * Cost;
+      return getBuildVectorCost(ValVTy, CostKind) + Num * Cost;
     }
 
     // Unknown scalar opcode.
@@ -1821,9 +1840,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     if (RetVF.isVector() && !RetVF.isScalable()) {
       ScalarizationCost = 0;
       if (!RetTy->isVoidTy())
-        ScalarizationCost += getScalarizationOverhead(
-            cast<VectorType>(RetTy),
-            /*Insert*/ true, /*Extract*/ false, CostKind);
+        ScalarizationCost += getBuildVectorCost(
+            cast<VectorType>(RetTy), CostKind);
       ScalarizationCost +=
           getOperandsScalarizationOverhead(Args, ICA.getArgTypes(), CostKind);
     }
@@ -1877,8 +1895,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       Type *ScalarRetTy = RetTy;
       if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
         if (!SkipScalarizationCost)
-          ScalarizationCost = getScalarizationOverhead(
-              RetVTy, /*Insert*/ true, /*Extract*/ false, CostKind);
+          ScalarizationCost = getBuildVectorCost(RetVTy, CostKind);
         ScalarCalls = std::max(ScalarCalls,
                                cast<FixedVectorType>(RetVTy)->getNumElements());
         ScalarRetTy = RetTy->getScalarType();
@@ -1888,8 +1905,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         Type *Ty = Tys[i];
         if (auto *VTy = dyn_cast<VectorType>(Ty)) {
           if (!SkipScalarizationCost)
-            ScalarizationCost += getScalarizationOverhead(
-                VTy, /*Insert*/ false, /*Extract*/ true, CostKind);
+            ScalarizationCost += getExplodeVectorCost(VTy, CostKind);
           ScalarCalls = std::max(ScalarCalls,
                                  cast<FixedVectorType>(VTy)->getNumElements());
           Ty = Ty->getScalarType();
@@ -2299,8 +2315,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       InstructionCost ScalarizationCost =
           SkipScalarizationCost
               ? ScalarizationCostPassed
-              : getScalarizationOverhead(RetVTy, /*Insert*/ true,
-                                         /*Extract*/ false, CostKind);
+              : getBuildVectorCost(RetVTy, CostKind);
 
       unsigned ScalarCalls = cast<FixedVectorType>(RetVTy)->getNumElements();
       SmallVector<Type *, 4> ScalarTys;
@@ -2316,8 +2331,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
         if (auto *VTy = dyn_cast<VectorType>(Tys[i])) {
           if (!ICA.skipScalarizationCost())
-            ScalarizationCost += getScalarizationOverhead(
-                VTy, /*Insert*/ false, /*Extract*/ true, CostKind);
+            ScalarizationCost += getExplodeVectorCost(VTy, CostKind);
           ScalarCalls = std::max(ScalarCalls,
                                  cast<FixedVectorType>(VTy)->getNumElements());
         }
@@ -2462,8 +2476,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       return InstructionCost::getInvalid();
 
     auto *VTy = cast<FixedVectorType>(Ty);
-    InstructionCost ExtractCost = getScalarizationOverhead(
-        VTy, /*Insert=*/false, /*Extract=*/true, CostKind);
+    InstructionCost ExtractCost = getExplodeVectorCost(VTy, CostKind);
     InstructionCost ArithCost = thisT()->getArithmeticInstrCost(
         Opcode, VTy->getElementType(), CostKind);
     ArithCost *= VTy->getNumElements();
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 3be894ad3bef2c..1e0bac8e540c2e 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1025,10 +1025,8 @@ InstructionCost ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
     if (Opcode == Instruction::FCmp && !ST->hasMVEFloatOps()) {
       // One scalaization insert, one scalarization extract and the cost of the
       // fcmps.
-      return BaseT::getScalarizationOverhead(VecValTy, /*Insert*/ false,
-                                             /*Extract*/ true, CostKind) +
-             BaseT::getScalarizationOverhead(VecCondTy, /*Insert*/ true,
-                                             /*Extract*/ false, CostKind) +
+      return BaseT::getExplodeVectorCost(VecValTy, CostKind) +
+             BaseT::getBuildVectorCost(VecCondTy, CostKind) +
              VecValTy->getNumElements() *
                  getCmpSelInstrCost(Opcode, ValTy->getScalarType(),
                                     VecCondTy->getScalarType(), VecPred,
@@ -1045,8 +1043,7 @@ InstructionCost ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
     if (LT.second.isVector() && LT.second.getVectorNumElements() > 2) {
       if (LT.first > 1)
         return LT.first * BaseCost +
-               BaseT::getScalarizationOverhead(VecCondTy, /*Insert*/ true,
-                                               /*Extract*/ false, CostKind);
+               BaseT::getBuildVectorCost(VecCondTy, CostKind);
       return BaseCost;
     }
   }
@@ -1599,10 +1596,8 @@ InstructionCost ARMTTIImpl::getGatherScatterOpCost(
   // greatly increasing the cost.
   InstructionCost ScalarCost =
       NumElems * LT.first + (VariableMask ? NumElems * 5 : 0) +
-      BaseT::getScalarizationOverhead(VTy, /*Insert*/ true, /*Extract*/ false,
-                                      CostKind) +
-      BaseT::getScalarizationOverhead(VTy, /*Insert*/ false, /*Extract*/ true,
-                                      CostKind);
+      BaseT::getBuildVectorCost(VTy, CostKind) +
+      BaseT::getExplodeVectorCost(VTy, CostKind);
 
   if (EltSize < 8 || Alignment < EltSize / 8)
     return ScalarCost;
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index e4adb7be564952..69ed6ae8d33597 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -910,10 +910,10 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
           (Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI))
         NeedsExtracts = false;
 
-      TotCost += getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
-                                          NeedsExtracts, CostKind);
-      TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts,
-                                          /*Extract*/ false, CostKind);
+      if (NeedsExtracts)
+        TotCost += getExplodeVectorCost(SrcVecTy, CostKind);
+      if (NeedsInserts)
+        TotCost += getBuildVectorCost(DstVecTy, CostKind);
 
       // FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4.
       if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32)
@@ -925,8 +925,7 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     if (Opcode == Instruction::FPTrunc) {
       if (SrcScalarBits == 128)  // fp128 -> double/float + inserts of elements.
         return VF /*ldxbr/lexbr*/ +
-               getScalarizationOverhead(DstVecTy, /*Insert*/ true,
-                                        /*Extract*/ false, CostKind);
+               getBuildVectorCost(DstVecTy, CostKind);
       else // double -> float
         return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/);
     }
@@ -939,8 +938,7 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
         return VF * 2;
       }
       // -> fp128.  VF * lxdb/lxeb + extraction of elements.
-      return VF + getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
-                                           /*Extract*/ true, CostKind);
+      return VF + getExplodeVectorCost(SrcVecTy, CostKind);
     }
   }
 
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index d336ab9d309c4e..e0ab99ae729166 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -4530,8 +4530,7 @@ X86TTIImpl::getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
       // For types we can insert directly, insertion into 128-bit sub vectors is
       // cheap, followed by a cheap chain of concatenations.
       if (LegalVectorBitWidth <= LaneBitWidth) {
-        Cost += BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert,
-                                                /*Extract*/ false, CostKind);
+        Cost += BaseT::getBuildVectorCost(Ty, DemandedElts, CostKind);
       } else {
         // In each 128-lane, if at least one index is demanded but not all
         // indices are demanded and this 128-lane is not the first 128-lane of
@@ -4570,8 +4569,7 @@ X86TTIImpl::getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
           if (!LaneEltMask.isAllOnes())
             Cost += getShuffleCost(TTI::SK_ExtractSubvector, Ty, std::nullopt,
                                    CostKind, I * NumEltsPerLane, LaneTy);
-          Cost += BaseT::getScalarizationOverhead(LaneTy, LaneEltMask, Insert,
-                                                  /*Extract*/ false, CostKind);
+          Cost += BaseT::getBuildVectorCost(LaneTy, LaneEltMask, CostKind);
         }
 
         APInt AffectedLanes =
@@ -4648,8 +4646,7 @@ X86TTIImpl::getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
             continue;
           Cost += getShuffleCost(TTI::SK_ExtractSubvector, Ty, std::nullopt,
                                  CostKind, I * NumEltsPerLane, LaneTy);
-          Cost += BaseT::getScalarizationOverhead(
-              LaneTy, LaneEltMask, /*Insert*/ false, Extract, CostKind);
+          Cost += BaseT::getExplodeVectorCost(LaneTy, LaneEltMask, CostKind);
         }
 
         return Cost;
@@ -4657,8 +4654,7 @@ X86TTIImpl::getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
     }
 
     // Fallback to default extraction.
-    Cost += BaseT::getScalarizationOverhead(Ty, DemandedElts, /*Insert*/ false,
-                                            Extract, CostKind);
+    Cost += BaseT::getExplodeVectorCost(Ty, DemandedElts, CostKind);
   }
 
   return Cost;


