[llvm] b873aba - [LoopVectorizer] NFCI: Calculate register usage based on TLI.getTypeLegalizationCost.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 11 02:21:03 PST 2020


Author: Sander de Smalen
Date: 2020-11-11T10:18:50Z
New Revision: b873aba3943c067a5efd5303cbdf5aeb0732cf88

URL: https://github.com/llvm/llvm-project/commit/b873aba3943c067a5efd5303cbdf5aeb0732cf88
DIFF: https://github.com/llvm/llvm-project/commit/b873aba3943c067a5efd5303cbdf5aeb0732cf88.diff

LOG: [LoopVectorizer] NFCI: Calculate register usage based on TLI.getTypeLegalizationCost.

This is more accurate than dividing the vector's bit width (element count times
scalar element size) by the maximum register size, as it simply reuses whatever
has already been calculated for the legalization of these types.

This change is also necessary when calculating register usage for scalable vectors, where
the legalization of these types cannot be done based on the widest register size, because
that does not take the 'vscale' component into account.

Reviewed By: SjoerdMeijer

Differential Revision: https://reviews.llvm.org/D91059

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 2b9dc2bf129d..b44d3541931b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -708,6 +708,9 @@ class TargetTransformInfo {
   /// Return true if this type is legal.
   bool isTypeLegal(Type *Ty) const;
 
+  /// Returns the estimated number of registers required to represent \p Ty.
+  unsigned getRegUsageForType(Type *Ty) const;
+
   /// Return true if switches should be turned into lookup tables for the
   /// target.
   bool shouldBuildLookupTables() const;
@@ -1447,6 +1450,7 @@ class TargetTransformInfo::Concept {
   virtual bool isProfitableToHoist(Instruction *I) = 0;
   virtual bool useAA() = 0;
   virtual bool isTypeLegal(Type *Ty) = 0;
+  virtual unsigned getRegUsageForType(Type *Ty) = 0;
   virtual bool shouldBuildLookupTables() = 0;
   virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
   virtual bool useColdCCForColdCall(Function &F) = 0;
@@ -1807,6 +1811,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
   bool useAA() override { return Impl.useAA(); }
   bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); }
+  unsigned getRegUsageForType(Type *Ty) override {
+    return Impl.getRegUsageForType(Ty);
+  }
   bool shouldBuildLookupTables() override {
     return Impl.shouldBuildLookupTables();
   }

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 86cab647c603..28ebf21164da 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -259,6 +259,8 @@ class TargetTransformInfoImplBase {
 
   bool isTypeLegal(Type *Ty) { return false; }
 
+  unsigned getRegUsageForType(Type *Ty) { return 1; }
+
   bool shouldBuildLookupTables() { return true; }
   bool shouldBuildLookupTablesForConstant(Constant *C) { return true; }
 

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 663c9460cfba..04aecca118d8 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -297,6 +297,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return getTLI()->isTypeLegal(VT);
   }
 
+  unsigned getRegUsageForType(Type *Ty) {
+    return getTLI()->getTypeLegalizationCost(DL, Ty).first;
+  }
+
   int getGEPCost(Type *PointeeType, const Value *Ptr,
                  ArrayRef<const Value *> Operands) {
     return BaseT::getGEPCost(PointeeType, Ptr, Operands);

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 6443aad1cffb..a9f41ed452c6 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -482,6 +482,10 @@ bool TargetTransformInfo::isTypeLegal(Type *Ty) const {
   return TTIImpl->isTypeLegal(Ty);
 }
 
+unsigned TargetTransformInfo::getRegUsageForType(Type *Ty) const {
+  return TTIImpl->getRegUsageForType(Ty);
+}
+
 bool TargetTransformInfo::shouldBuildLookupTables() const {
   return TTIImpl->shouldBuildLookupTables();
 }

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ec136162fbb6..5d50dce7fcbc 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5793,8 +5793,6 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
   unsigned MaxSafeDepDist = -1U;
   if (Legal->getMaxSafeDepDistBytes() != -1U)
     MaxSafeDepDist = Legal->getMaxSafeDepDistBytes() * 8;
-  unsigned WidestRegister =
-      std::min(TTI.getRegisterBitWidth(true), MaxSafeDepDist);
   const DataLayout &DL = TheFunction->getParent()->getDataLayout();
 
   SmallVector<RegisterUsage, 8> RUs(VFs.size());
@@ -5803,13 +5801,10 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
   LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n");
 
   // A lambda that gets the register usage for the given type and VF.
-  auto GetRegUsage = [&DL, WidestRegister](Type *Ty, ElementCount VF) {
+  auto GetRegUsage = [&DL, &TTI=TTI](Type *Ty, ElementCount VF) {
     if (Ty->isTokenTy())
       return 0U;
-    unsigned TypeSize = DL.getTypeSizeInBits(Ty->getScalarType());
-    assert(!VF.isScalable() && "scalable vectors not yet supported.");
-    return std::max<unsigned>(1, VF.getKnownMinValue() * TypeSize /
-                                     WidestRegister);
+    return TTI.getRegUsageForType(VectorType::get(Ty, VF));
   };
 
   for (unsigned int i = 0, s = IdxToInstr.size(); i < s; ++i) {


        


More information about the llvm-commits mailing list