[llvm] d484cc1 - [TTI] Adjust `getReplicationShuffleCost()` interface

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 9 03:08:21 PST 2021


Author: Roman Lebedev
Date: 2021-11-09T14:07:59+03:00
New Revision: d484cc152b1d9282230a17a218337342d52536e2

URL: https://github.com/llvm/llvm-project/commit/d484cc152b1d9282230a17a218337342d52536e2
DIFF: https://github.com/llvm/llvm-project/commit/d484cc152b1d9282230a17a218337342d52536e2.diff

LOG: [TTI] Adjust `getReplicationShuffleCost()` interface

It is trivial to compute DemandedSrcElts from DemandedReplicatedElts,
so don't pass the former. Also, the overload taking the Mask isn't
really useful so far, so just inline its logic at the call site.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 4312c2ae0de6..e93a1e2b7aaf 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1127,12 +1127,8 @@ class TargetTransformInfo {
   ///   <0,0,0,1,1,1,2,2,2,3,3,3>
   InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
                                             int VF,
-                                            const APInt &DemandedSrcElts,
                                             const APInt &DemandedReplicatedElts,
                                             TTI::TargetCostKind CostKind);
-  InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
-                                            int VF, ArrayRef<int> Mask,
-                                            TTI::TargetCostKind CostKind);
 
   /// \return The cost of Load and Store instructions.
   InstructionCost
@@ -1661,12 +1657,9 @@ class TargetTransformInfo::Concept {
   virtual InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
                                              unsigned Index) = 0;
 
-  virtual InstructionCost getReplicationShuffleCost(
-      Type *EltTy, int ReplicationFactor, int VF, const APInt &DemandedSrcElts,
-      const APInt &DemandedReplicatedElts, TTI::TargetCostKind CostKind) = 0;
   virtual InstructionCost
   getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
-                            ArrayRef<int> Mask,
+                            const APInt &DemandedReplicatedElts,
                             TTI::TargetCostKind CostKind) = 0;
 
   virtual InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src,
@@ -2180,20 +2173,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
   InstructionCost
   getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
-                            const APInt &DemandedSrcElts,
                             const APInt &DemandedReplicatedElts,
                             TTI::TargetCostKind CostKind) override {
     return Impl.getReplicationShuffleCost(EltTy, ReplicationFactor, VF,
-                                          DemandedSrcElts,
                                           DemandedReplicatedElts, CostKind);
   }
-  InstructionCost
-  getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
-                            ArrayRef<int> Mask,
-                            TTI::TargetCostKind CostKind) override {
-    return Impl.getReplicationShuffleCost(EltTy, ReplicationFactor, VF, Mask,
-                                          CostKind);
-  }
   InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment,
                                   unsigned AddressSpace,
                                   TTI::TargetCostKind CostKind,

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 707912dd3873..a0bae8ed29a1 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -552,16 +552,10 @@ class TargetTransformInfoImplBase {
   }
 
   unsigned getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
-                                     const APInt &DemandedSrcElts,
                                      const APInt &DemandedReplicatedElts,
                                      TTI::TargetCostKind CostKind) {
     return 1;
   }
-  unsigned getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
-                                     ArrayRef<int> Mask,
-                                     TTI::TargetCostKind CostKind) {
-    return 1;
-  }
 
   InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment,
                                   unsigned AddressSpace,
@@ -1119,10 +1113,17 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
               FixedVectorType::get(VecTy->getScalarType(), NumSubElts));
 
         int ReplicationFactor, VF;
-        if (Shuffle->isReplicationMask(ReplicationFactor, VF))
+        if (Shuffle->isReplicationMask(ReplicationFactor, VF)) {
+          APInt DemandedReplicatedElts =
+              APInt::getNullValue(Shuffle->getShuffleMask().size());
+          for (auto I : enumerate(Shuffle->getShuffleMask())) {
+            if (I.value() != UndefMaskElem)
+              DemandedReplicatedElts.setBit(I.index());
+          }
           return TargetTTI->getReplicationShuffleCost(
               VecSrcTy->getElementType(), ReplicationFactor, VF,
-              Shuffle->getShuffleMask(), CostKind);
+              DemandedReplicatedElts, CostKind);
+        }
 
         return CostKind == TTI::TCK_RecipThroughput ? -1 : 1;
       }

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index aeefa01e2ff0..aeea6e459eac 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1121,9 +1121,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
                                             int VF,
-                                            const APInt &DemandedSrcElts,
                                             const APInt &DemandedReplicatedElts,
                                             TTI::TargetCostKind CostKind) {
+    assert(DemandedReplicatedElts.getBitWidth() ==
+               (unsigned)VF * ReplicationFactor &&
+           "Unexpected size of DemandedReplicatedElts.");
+
     InstructionCost Cost;
 
     auto *SrcVT = FixedVectorType::get(EltTy, VF);
@@ -1139,6 +1142,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     // The cost is estimated as extract all mask elements from the <8xi1> mask
     // vector and insert them factor times into the <24xi1> shuffled mask
     // vector.
+    APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedReplicatedElts, VF);
     Cost += thisT()->getScalarizationOverhead(SrcVT, DemandedSrcElts,
                                               /*Insert*/ false,
                                               /*Extract*/ true);
@@ -1149,41 +1153,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return Cost;
   }
 
-  InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
-                                            int VF, ArrayRef<int> Mask,
-                                            TTI::TargetCostKind CostKind) {
-    assert(Mask.size() == (unsigned)VF * ReplicationFactor && "Bad mask size.");
-
-    APInt DemandedSrcElts = APInt::getNullValue(VF);
-
-    ArrayRef<int> RemainingMask = Mask;
-    for (int i = 0; i < VF; i++) {
-      ArrayRef<int> CurrSubMask = RemainingMask.take_front(ReplicationFactor);
-      RemainingMask = RemainingMask.drop_front(CurrSubMask.size());
-
-      assert(all_of(CurrSubMask,
-                    [i](int MaskElt) {
-                      return MaskElt == UndefMaskElem || MaskElt == i;
-                    }) &&
-             "Not a replication mask.");
-
-      if (any_of(CurrSubMask,
-                 [](int MaskElt) { return MaskElt != UndefMaskElem; }))
-        DemandedSrcElts.setBit(i);
-    }
-    assert(RemainingMask.empty() && "Did not consume the entire mask?");
-
-    APInt DemandedReplicatedElts = APInt::getNullValue(Mask.size());
-    for (auto I : enumerate(Mask)) {
-      if (I.value() != UndefMaskElem)
-        DemandedReplicatedElts.setBit(I.index());
-    }
-
-    return thisT()->getReplicationShuffleCost(EltTy, ReplicationFactor, VF,
-                                              DemandedSrcElts,
-                                              DemandedReplicatedElts, CostKind);
-  }
-
   InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src,
                                   MaybeAlign Alignment, unsigned AddressSpace,
                                   TTI::TargetCostKind CostKind,
@@ -1365,7 +1334,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     Type *I8Type = Type::getInt8Ty(VT->getContext());
 
     Cost += thisT()->getReplicationShuffleCost(
-        I8Type, Factor, NumSubElts, DemandedAllSubElts,
+        I8Type, Factor, NumSubElts,
         UseMaskForGaps ? DemandedLoadStoreElts : DemandedAllResultElts,
         CostKind);
 

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 8c5254dbf615..dcd015b36ee8 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -834,19 +834,10 @@ InstructionCost TargetTransformInfo::getVectorInstrCost(unsigned Opcode,
 }
 
 InstructionCost TargetTransformInfo::getReplicationShuffleCost(
-    Type *EltTy, int ReplicationFactor, int VF, const APInt &DemandedSrcElts,
+    Type *EltTy, int ReplicationFactor, int VF,
     const APInt &DemandedReplicatedElts, TTI::TargetCostKind CostKind) {
   InstructionCost Cost = TTIImpl->getReplicationShuffleCost(
-      EltTy, ReplicationFactor, VF, DemandedSrcElts, DemandedReplicatedElts,
-      CostKind);
-  assert(Cost >= 0 && "TTI should not produce negative costs!");
-  return Cost;
-}
-InstructionCost TargetTransformInfo::getReplicationShuffleCost(
-    Type *EltTy, int ReplicationFactor, int VF, ArrayRef<int> Mask,
-    TTI::TargetCostKind CostKind) {
-  InstructionCost Cost = TTIImpl->getReplicationShuffleCost(
-      EltTy, ReplicationFactor, VF, Mask, CostKind);
+      EltTy, ReplicationFactor, VF, DemandedReplicatedElts, CostKind);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index ebde29f0ba6e..83f4e0bce175 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -5077,7 +5077,7 @@ InstructionCost X86TTIImpl::getInterleavedMemoryOpCostAVX512(
     Type *I8Type = Type::getInt8Ty(VecTy->getContext());
 
     MaskCost = getReplicationShuffleCost(
-        I8Type, Factor, VF, APInt::getAllOnes(VF),
+        I8Type, Factor, VF,
         UseMaskForGaps ? DemandedLoadStoreElts
                        : APInt::getAllOnes(VecTy->getNumElements()),
         CostKind);


        


More information about the llvm-commits mailing list