[llvm] 889588e - [SLP] Refactoring isLegalBroadcastLoad() to use `ElementCount`.

Vasileios Porpodas via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 21 10:20:43 PDT 2022


Author: Vasileios Porpodas
Date: 2022-04-21T10:19:00-07:00
New Revision: 889588ee978ccfde69f4846b2e05cf23a36ce13e

URL: https://github.com/llvm/llvm-project/commit/889588ee978ccfde69f4846b2e05cf23a36ce13e
DIFF: https://github.com/llvm/llvm-project/commit/889588ee978ccfde69f4846b2e05cf23a36ce13e.diff

LOG: [SLP] Refactoring isLegalBroadcastLoad() to use `ElementCount`.

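Replace the `unsigned NumElements` parameter of `isLegalBroadcastLoad()` with
an `ElementCount`. The X86 implementation now explicitly rejects scalable
element counts before comparing the fixed value, and the SLP vectorizer wraps
its lane count with `ElementCount::getFixed()`.
This helps reduce the diff of a future SLP patch for AArch64.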

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7e2424d8eb044..8ac6485362b29 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -660,7 +660,7 @@ class TargetTransformInfo {
 
   /// \Returns true if the target supports broadcasting a load to a vector of
   /// type <NumElements x ElementTy>.
-  bool isLegalBroadcastLoad(Type *ElementTy, unsigned NumElements) const;
+  bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const;
 
   /// Return true if the target supports masked scatter.
   bool isLegalMaskedScatter(Type *DataType, Align Alignment) const;
@@ -1560,7 +1560,7 @@ class TargetTransformInfo::Concept {
   virtual bool isLegalNTStore(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalNTLoad(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalBroadcastLoad(Type *ElementTy,
-                                    unsigned NumElements) const = 0;
+                                    ElementCount NumElements) const = 0;
   virtual bool isLegalMaskedScatter(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalMaskedGather(Type *DataType, Align Alignment) = 0;
   virtual bool forceScalarizeMaskedGather(VectorType *DataType,
@@ -1968,7 +1968,7 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isLegalNTLoad(DataType, Alignment);
   }
   bool isLegalBroadcastLoad(Type *ElementTy,
-                            unsigned NumElements) const override {
+                            ElementCount NumElements) const override {
     return Impl.isLegalBroadcastLoad(ElementTy, NumElements);
   }
   bool isLegalMaskedScatter(Type *DataType, Align Alignment) override {

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 538e099cf76ce..ff73e62d1f332 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -256,7 +256,7 @@ class TargetTransformInfoImplBase {
     return Alignment >= DataSize && isPowerOf2_32(DataSize);
   }
 
-  bool isLegalBroadcastLoad(Type *ElementTy, unsigned NumElements) const {
+  bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const {
     return false;
   }
 

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 6186f0061eb8d..8a46569e1730c 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -397,7 +397,7 @@ bool TargetTransformInfo::isLegalNTLoad(Type *DataType, Align Alignment) const {
 }
 
 bool TargetTransformInfo::isLegalBroadcastLoad(Type *ElementTy,
-                                               unsigned NumElements) const {
+                                               ElementCount NumElements) const {
   return TTIImpl->isLegalBroadcastLoad(ElementTy, NumElements);
 }
 

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index abefa62c9d52c..1079bd68468b0 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -1558,7 +1558,7 @@ InstructionCost X86TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
       if (const auto *Entry =
               CostTableLookup(SSE3BroadcastLoadTbl, Kind, LT.second)) {
         assert(isLegalBroadcastLoad(BaseTp->getElementType(),
-                                    LT.second.getVectorNumElements()) &&
+                                    LT.second.getVectorElementCount()) &&
                "Table entry missing from isLegalBroadcastLoad()");
         return LT.first * Entry->Cost;
       }
@@ -5137,9 +5137,10 @@ bool X86TTIImpl::isLegalNTStore(Type *DataType, Align Alignment) {
 }
 
 bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy,
-                                      unsigned NumElements) const {
+                                      ElementCount NumElements) const {
   // movddup
-  return ST->hasSSE3() && NumElements == 2 &&
+  return ST->hasSSE3() && !NumElements.isScalable() &&
+         NumElements.getFixedValue() == 2 &&
          ElementTy == Type::getDoubleTy(ElementTy->getContext());
 }
 

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 4f874783f6989..3dbc2d0b6c2e3 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -232,7 +232,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedStore(Type *DataType, Align Alignment);
   bool isLegalNTLoad(Type *DataType, Align Alignment);
   bool isLegalNTStore(Type *DataType, Align Alignment);
-  bool isLegalBroadcastLoad(Type *ElementTy, unsigned NumElements) const;
+  bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const;
   bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment);
   bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) {
     return forceScalarizeMaskedGather(VTy, Alignment);

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index bc976c95dcc9b..2150ed99e79de 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1188,7 +1188,8 @@ class BoUpSLP {
             return AllUsersVectorized(V1) && AllUsersVectorized(V2);
           };
           // A broadcast of a load can be cheaper on some targets.
-          if (R.TTI->isLegalBroadcastLoad(V1->getType(), NumLanes) &&
+          if (R.TTI->isLegalBroadcastLoad(V1->getType(),
+                                          ElementCount::getFixed(NumLanes)) &&
               ((int)V1->getNumUses() == NumLanes ||
                AllUsersAreInternal(V1, V2)))
             return VLOperands::ScoreSplatLoads;


        


More information about the llvm-commits mailing list