[llvm] f8efc5c - [NFC][TTI] Add/extract `getReplicationShuffleCost()` method, deduplicate its implementations

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 6 06:55:39 PDT 2021


Author: Roman Lebedev
Date: 2021-11-06T16:45:15+03:00
New Revision: f8efc5c0ac68d2f94c8f83e65798e786e2c8c8cd

URL: https://github.com/llvm/llvm-project/commit/f8efc5c0ac68d2f94c8f83e65798e786e2c8c8cd
DIFF: https://github.com/llvm/llvm-project/commit/f8efc5c0ac68d2f94c8f83e65798e786e2c8c8cd.diff

LOG: [NFC][TTI] Add/extract `getReplicationShuffleCost()` method, deduplicate its implementations

Hiding it in `getInterleavedMemoryOpCost()` is problematic for a number of reasons,
including testability and reuse; let's do better.

In a follow-up, `getUserCost()` will be taught to use it to estimate the mask costs,
which will allow for better cost-model tests for the new method.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D113313

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 370ab30726848..10cc0518e3d74 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1113,6 +1113,17 @@ class TargetTransformInfo {
   InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
                                      unsigned Index = -1) const;
 
+  /// \return The cost of replication shuffle of \p VF elements typed \p EltTy
+  /// \p ReplicationFactor times.
+  ///
+  /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
+  ///   <0,0,0,1,1,1,2,2,2,3,3,3>
+  InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
+                                            int VF,
+                                            const APInt &DemandedSrcElts,
+                                            const APInt &DemandedReplicatedElts,
+                                            TTI::TargetCostKind CostKind);
+
   /// \return The cost of Load and Store instructions.
   InstructionCost
   getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment,
@@ -1636,6 +1647,11 @@ class TargetTransformInfo::Concept {
                                              const Instruction *I) = 0;
   virtual InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
                                              unsigned Index) = 0;
+
+  virtual InstructionCost getReplicationShuffleCost(
+      Type *EltTy, int ReplicationFactor, int VF, const APInt &DemandedSrcElts,
+      const APInt &DemandedReplicatedElts, TTI::TargetCostKind CostKind) = 0;
+
   virtual InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src,
                                           Align Alignment,
                                           unsigned AddressSpace,
@@ -2137,6 +2153,15 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                      unsigned Index) override {
     return Impl.getVectorInstrCost(Opcode, Val, Index);
   }
+  InstructionCost
+  getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
+                            const APInt &DemandedSrcElts,
+                            const APInt &DemandedReplicatedElts,
+                            TTI::TargetCostKind CostKind) override {
+    return Impl.getReplicationShuffleCost(EltTy, ReplicationFactor, VF,
+                                          DemandedSrcElts,
+                                          DemandedReplicatedElts, CostKind);
+  }
   InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment,
                                   unsigned AddressSpace,
                                   TTI::TargetCostKind CostKind,

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 07344fc05036c..d62e74db6a1a0 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -544,6 +544,13 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  unsigned getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
+                                     const APInt &DemandedSrcElts,
+                                     const APInt &DemandedReplicatedElts,
+                                     TTI::TargetCostKind CostKind) {
+    return 1;
+  }
+
   InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment,
                                   unsigned AddressSpace,
                                   TTI::TargetCostKind CostKind,

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 1ebab920b82fa..5dad685191927 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1113,6 +1113,36 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return LT.first;
   }
 
+  InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
+                                            int VF,
+                                            const APInt &DemandedSrcElts,
+                                            const APInt &DemandedReplicatedElts,
+                                            TTI::TargetCostKind CostKind) {
+    InstructionCost Cost;
+
+    auto *SrcVT = FixedVectorType::get(EltTy, VF);
+    auto *ReplicatedVT = FixedVectorType::get(EltTy, VF * ReplicationFactor);
+
+    // The Mask shuffling cost is extract all the elements of the Mask
+    // and insert each of them Factor times into the wide vector:
+    //
+    // E.g. an interleaved group with factor 3:
+    //    %mask = icmp ult <8 x i32> %vec1, %vec2
+    //    %interleaved.mask = shufflevector <8 x i1> %mask, <8 x i1> undef,
+    //        <24 x i32> <0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6,7,7,7>
+    // The cost is estimated as extract all mask elements from the <8xi1> mask
+    // vector and insert them factor times into the <24xi1> shuffled mask
+    // vector.
+    Cost += thisT()->getScalarizationOverhead(SrcVT, DemandedSrcElts,
+                                              /*Insert*/ false,
+                                              /*Extract*/ true);
+    Cost +=
+        thisT()->getScalarizationOverhead(ReplicatedVT, DemandedReplicatedElts,
+                                          /*Insert*/ true, /*Extract*/ false);
+
+    return Cost;
+  }
+
   InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src,
                                   MaybeAlign Alignment, unsigned AddressSpace,
                                   TTI::TargetCostKind CostKind,
@@ -1292,34 +1322,22 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       return Cost;
 
     Type *I8Type = Type::getInt8Ty(VT->getContext());
-    auto *MaskVT = FixedVectorType::get(I8Type, NumElts);
-    SubVT = FixedVectorType::get(I8Type, NumSubElts);
 
-    // The Mask shuffling cost is extract all the elements of the Mask
-    // and insert each of them Factor times into the wide vector:
-    //
-    // E.g. an interleaved group with factor 3:
-    //    %mask = icmp ult <8 x i32> %vec1, %vec2
-    //    %interleaved.mask = shufflevector <8 x i1> %mask, <8 x i1> undef,
-    //        <24 x i32> <0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6,7,7,7>
-    // The cost is estimated as extract all mask elements from the <8xi1> mask
-    // vector and insert them factor times into the <24xi1> shuffled mask
-    // vector.
-    Cost +=
-        thisT()->getScalarizationOverhead(SubVT, DemandedAllSubElts,
-                                          /*Insert*/ false, /*Extract*/ true);
-    Cost += thisT()->getScalarizationOverhead(
-        MaskVT, UseMaskForGaps ? DemandedLoadStoreElts : DemandedAllResultElts,
-        /*Insert*/ true, /*Extract*/ false);
+    Cost += thisT()->getReplicationShuffleCost(
+        I8Type, Factor, NumSubElts, DemandedAllSubElts,
+        UseMaskForGaps ? DemandedLoadStoreElts : DemandedAllResultElts,
+        CostKind);
 
     // The Gaps mask is invariant and created outside the loop, therefore the
     // cost of creating it is not accounted for here. However if we have both
     // a MaskForGaps and some other mask that guards the execution of the
     // memory access, we need to account for the cost of And-ing the two masks
     // inside the loop.
-    if (UseMaskForGaps)
+    if (UseMaskForGaps) {
+      auto *MaskVT = FixedVectorType::get(I8Type, NumElts);
       Cost += thisT()->getArithmeticInstrCost(BinaryOperator::And, MaskVT,
                                               CostKind);
+    }
 
     return Cost;
   }

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 800896fa6a053..1fe9c3a0f0cfe 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -824,6 +824,15 @@ InstructionCost TargetTransformInfo::getVectorInstrCost(unsigned Opcode,
   return Cost;
 }
 
+InstructionCost TargetTransformInfo::getReplicationShuffleCost(
+    Type *EltTy, int ReplicationFactor, int VF, const APInt &DemandedSrcElts,
+    const APInt &DemandedReplicatedElts, TTI::TargetCostKind CostKind) {
+  InstructionCost Cost = TTIImpl->getReplicationShuffleCost(
+      EltTy, ReplicationFactor, VF, DemandedSrcElts, DemandedReplicatedElts,
+      CostKind);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
 InstructionCost TargetTransformInfo::getMemoryOpCost(
     unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace,
     TTI::TargetCostKind CostKind, const Instruction *I) const {

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index e13bdfd314e3b..ebde29f0ba6e6 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -5075,36 +5075,22 @@ InstructionCost X86TTIImpl::getInterleavedMemoryOpCostAVX512(
     }
 
     Type *I8Type = Type::getInt8Ty(VecTy->getContext());
-    auto *MaskVT = FixedVectorType::get(I8Type, VecTy->getNumElements());
-    auto *MaskSubVT = FixedVectorType::get(I8Type, VF);
-
-    // The Mask shuffling cost is extract all the elements of the Mask
-    // and insert each of them Factor times into the wide vector:
-    //
-    // E.g. an interleaved group with factor 3:
-    //    %mask = icmp ult <8 x i32> %vec1, %vec2
-    //    %interleaved.mask = shufflevector <8 x i1> %mask, <8 x i1> undef,
-    //        <24 x i32> <0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6,7,7,7>
-    // The cost is estimated as extract all mask elements from the <8xi1> mask
-    // vector and insert them factor times into the <24xi1> shuffled mask
-    // vector.
-    MaskCost += getScalarizationOverhead(
-        MaskSubVT, APInt::getAllOnes(MaskSubVT->getNumElements()),
-        /*Insert*/ false, /*Extract*/ true);
-    MaskCost += getScalarizationOverhead(
-        MaskVT,
+
+    MaskCost = getReplicationShuffleCost(
+        I8Type, Factor, VF, APInt::getAllOnes(VF),
         UseMaskForGaps ? DemandedLoadStoreElts
                        : APInt::getAllOnes(VecTy->getNumElements()),
-        /*Insert*/ true,
-        /*Extract*/ false);
+        CostKind);
 
     // The Gaps mask is invariant and created outside the loop, therefore the
     // cost of creating it is not accounted for here. However if we have both
     // a MaskForGaps and some other mask that guards the execution of the
     // memory access, we need to account for the cost of And-ing the two masks
     // inside the loop.
-    if (UseMaskForGaps)
+    if (UseMaskForGaps) {
+      auto *MaskVT = FixedVectorType::get(I8Type, VecTy->getNumElements());
       MaskCost += getArithmeticInstrCost(BinaryOperator::And, MaskVT, CostKind);
+    }
   }
 
   if (Opcode == Instruction::Load) {


        


More information about the llvm-commits mailing list