[llvm] f571fe6 - Reland [LoopVectorizer] NFCI: Calculate register usage based on TLI.getTypeLegalizationCost.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 17 05:47:12 PST 2020


Author: Sander de Smalen
Date: 2020-11-17T13:45:10Z
New Revision: f571fe6df585127d8b045f8e8f5b4e59da9bbb73

URL: https://github.com/llvm/llvm-project/commit/f571fe6df585127d8b045f8e8f5b4e59da9bbb73
DIFF: https://github.com/llvm/llvm-project/commit/f571fe6df585127d8b045f8e8f5b4e59da9bbb73.diff

LOG: Reland [LoopVectorizer] NFCI: Calculate register usage based on TLI.getTypeLegalizationCost.

This relands https://reviews.llvm.org/D91059 and reverts commit
30fded75b48bcbc034120154a57a00c7f3d07e06.

The GetRegUsage lambda in LoopVectorizationCostModel::calculateRegisterUsage
(LoopVectorize.cpp) now returns 0 when Ty is not a valid vector element type.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 9baea335d581..77952247fe3c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -710,6 +710,9 @@ class TargetTransformInfo {
   /// Return true if this type is legal.
   bool isTypeLegal(Type *Ty) const;
 
+  /// Returns the estimated number of registers required to represent \p Ty.
+  unsigned getRegUsageForType(Type *Ty) const;
+
   /// Return true if switches should be turned into lookup tables for the
   /// target.
   bool shouldBuildLookupTables() const;
@@ -1450,6 +1453,7 @@ class TargetTransformInfo::Concept {
   virtual bool isProfitableToHoist(Instruction *I) = 0;
   virtual bool useAA() = 0;
   virtual bool isTypeLegal(Type *Ty) = 0;
+  virtual unsigned getRegUsageForType(Type *Ty) = 0;
   virtual bool shouldBuildLookupTables() = 0;
   virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
   virtual bool useColdCCForColdCall(Function &F) = 0;
@@ -1814,6 +1818,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
   bool useAA() override { return Impl.useAA(); }
   bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); }
+  unsigned getRegUsageForType(Type *Ty) override {
+    return Impl.getRegUsageForType(Ty);
+  }
   bool shouldBuildLookupTables() override {
     return Impl.shouldBuildLookupTables();
   }

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 34e26db8d4f0..ed997406f8a2 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -261,6 +261,8 @@ class TargetTransformInfoImplBase {
 
   bool isTypeLegal(Type *Ty) { return false; }
 
+  unsigned getRegUsageForType(Type *Ty) { return 1; }
+
   bool shouldBuildLookupTables() { return true; }
   bool shouldBuildLookupTablesForConstant(Constant *C) { return true; }
 

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index a2ee5d6835f5..0b3d7704028a 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -301,6 +301,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return getTLI()->isTypeLegal(VT);
   }
 
+  unsigned getRegUsageForType(Type *Ty) {
+    return getTLI()->getTypeLegalizationCost(DL, Ty).first;
+  }
+
   int getGEPCost(Type *PointeeType, const Value *Ptr,
                  ArrayRef<const Value *> Operands) {
     return BaseT::getGEPCost(PointeeType, Ptr, Operands);

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 315e0060834a..b59a1a403e65 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -486,6 +486,10 @@ bool TargetTransformInfo::isTypeLegal(Type *Ty) const {
   return TTIImpl->isTypeLegal(Ty);
 }
 
+unsigned TargetTransformInfo::getRegUsageForType(Type *Ty) const {
+  return TTIImpl->getRegUsageForType(Ty);
+}
+
 bool TargetTransformInfo::shouldBuildLookupTables() const {
   return TTIImpl->shouldBuildLookupTables();
 }

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index caa5533cc344..8376bd2ce9e6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5793,9 +5793,6 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
   unsigned MaxSafeDepDist = -1U;
   if (Legal->getMaxSafeDepDistBytes() != -1U)
     MaxSafeDepDist = Legal->getMaxSafeDepDistBytes() * 8;
-  unsigned WidestRegister =
-      std::min(TTI.getRegisterBitWidth(true), MaxSafeDepDist);
-  const DataLayout &DL = TheFunction->getParent()->getDataLayout();
 
   SmallVector<RegisterUsage, 8> RUs(VFs.size());
   SmallVector<SmallMapVector<unsigned, unsigned, 4>, 8> MaxUsages(VFs.size());
@@ -5803,13 +5800,11 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
   LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n");
 
   // A lambda that gets the register usage for the given type and VF.
-  auto GetRegUsage = [&DL, WidestRegister](Type *Ty, ElementCount VF) {
-    if (Ty->isTokenTy())
+  const auto &TTICapture = TTI;
+  auto GetRegUsage = [&TTICapture](Type *Ty, ElementCount VF) {
+    if (Ty->isTokenTy() || !VectorType::isValidElementType(Ty))
       return 0U;
-    unsigned TypeSize = DL.getTypeSizeInBits(Ty->getScalarType());
-    assert(!VF.isScalable() && "scalable vectors not yet supported.");
-    return std::max<unsigned>(1, VF.getKnownMinValue() * TypeSize /
-                                     WidestRegister);
+    return TTICapture.getRegUsageForType(VectorType::get(Ty, VF));
   };
 
   for (unsigned int i = 0, s = IdxToInstr.size(); i < s; ++i) {


        


More information about the llvm-commits mailing list