[llvm] 55d18b3 - [TTI] Return a TypeSize from getRegisterBitWidth.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 24 07:45:36 PDT 2021


Author: Sander de Smalen
Date: 2021-03-24T14:45:13Z
New Revision: 55d18b3cc236f13e67e22cb8c1f13f92c10306c0

URL: https://github.com/llvm/llvm-project/commit/55d18b3cc236f13e67e22cb8c1f13f92c10306c0
DIFF: https://github.com/llvm/llvm-project/commit/55d18b3cc236f13e67e22cb8c1f13f92c10306c0.diff

LOG: [TTI] Return a TypeSize from getRegisterBitWidth.

This patch changes the interface to take a RegisterKind, which indicates
whether the bit width of a scalar register, a fixed-width vector register,
or a scalable vector register is being requested. The result is now
returned as a TypeSize instead of an unsigned.

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D98874

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/CodeGen/TypePromotion.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
    llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
    llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
    llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
    llvm/lib/Target/VE/VETargetTransformInfo.h
    llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
    llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index dad1381ea8b8..1038a39bfb3d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -920,8 +920,10 @@ class TargetTransformInfo {
   /// \return the target-provided register class name
   const char *getRegisterClassName(unsigned ClassID) const;
 
+  enum RegisterKind { RGK_Scalar, RGK_FixedWidthVector, RGK_ScalableVector };
+
   /// \return The width of the largest scalar or vector register type.
-  unsigned getRegisterBitWidth(bool Vector) const;
+  TypeSize getRegisterBitWidth(RegisterKind K) const;
 
   /// \return The width of the smallest vector register type.
   unsigned getMinVectorRegisterBitWidth() const;
@@ -1518,7 +1520,7 @@ class TargetTransformInfo::Concept {
   virtual unsigned getRegisterClassForType(bool Vector,
                                            Type *Ty = nullptr) const = 0;
   virtual const char *getRegisterClassName(unsigned ClassID) const = 0;
-  virtual unsigned getRegisterBitWidth(bool Vector) const = 0;
+  virtual TypeSize getRegisterBitWidth(RegisterKind K) const = 0;
   virtual unsigned getMinVectorRegisterBitWidth() = 0;
   virtual Optional<unsigned> getMaxVScale() const = 0;
   virtual bool shouldMaximizeVectorBandwidth(bool OptSize) const = 0;
@@ -1944,8 +1946,8 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   const char *getRegisterClassName(unsigned ClassID) const override {
     return Impl.getRegisterClassName(ClassID);
   }
-  unsigned getRegisterBitWidth(bool Vector) const override {
-    return Impl.getRegisterBitWidth(Vector);
+  TypeSize getRegisterBitWidth(RegisterKind K) const override {
+    return Impl.getRegisterBitWidth(K);
   }
   unsigned getMinVectorRegisterBitWidth() override {
     return Impl.getMinVectorRegisterBitWidth();

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 4cf5337de8cf..b81227759f14 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -380,7 +380,9 @@ class TargetTransformInfoImplBase {
     }
   }
 
-  unsigned getRegisterBitWidth(bool Vector) const { return 32; }
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+    return TypeSize::getFixed(32);
+  }
 
   unsigned getMinVectorRegisterBitWidth() const { return 128; }
 

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index a0fd0dadbc18..939037edf3d4 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -570,7 +570,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   /// \name Vector TTI Implementations
   /// @{
 
-  unsigned getRegisterBitWidth(bool Vector) const { return 32; }
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+    return TypeSize::getFixed(32);
+  }
 
   Optional<unsigned> getMaxVScale() const { return None; }
 

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 37da50f8015c..7fa6ae13ae48 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -577,8 +577,9 @@ const char *TargetTransformInfo::getRegisterClassName(unsigned ClassID) const {
   return TTIImpl->getRegisterClassName(ClassID);
 }
 
-unsigned TargetTransformInfo::getRegisterBitWidth(bool Vector) const {
-  return TTIImpl->getRegisterBitWidth(Vector);
+TypeSize TargetTransformInfo::getRegisterBitWidth(
+    TargetTransformInfo::RegisterKind K) const {
+  return TTIImpl->getRegisterBitWidth(K);
 }
 
 unsigned TargetTransformInfo::getMinVectorRegisterBitWidth() const {

diff  --git a/llvm/lib/CodeGen/TypePromotion.cpp b/llvm/lib/CodeGen/TypePromotion.cpp
index a42095d8718a..1a488863a8cd 100644
--- a/llvm/lib/CodeGen/TypePromotion.cpp
+++ b/llvm/lib/CodeGen/TypePromotion.cpp
@@ -952,7 +952,8 @@ bool TypePromotion::runOnFunction(Function &F) {
   const TargetLowering *TLI = SubtargetInfo->getTargetLowering();
   const TargetTransformInfo &TII =
     getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
-  RegisterBitWidth = TII.getRegisterBitWidth(false);
+  RegisterBitWidth =
+      TII.getRegisterBitWidth(TargetTransformInfo::RGK_Scalar).getFixedSize();
   Ctx = &F.getParent()->getContext();
 
   // Search up from icmps to try to promote their operands.

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index afb470592c8b..5dfa51882820 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -100,15 +100,19 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   unsigned getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                  TTI::TargetCostKind CostKind);
 
-  unsigned getRegisterBitWidth(bool Vector) const {
-    if (Vector) {
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+    switch (K) {
+    case TargetTransformInfo::RGK_Scalar:
+      return TypeSize::getFixed(64);
+    case TargetTransformInfo::RGK_FixedWidthVector:
       if (ST->hasSVE())
-        return std::max(ST->getMinSVEVectorSizeInBits(), 128u);
-      if (ST->hasNEON())
-        return 128;
-      return 0;
+        return TypeSize::getFixed(
+            std::max(ST->getMinSVEVectorSizeInBits(), 128u));
+      return TypeSize::getFixed(ST->hasNEON() ? 128 : 0);
+    case TargetTransformInfo::RGK_ScalableVector:
+      return TypeSize::getScalable(ST->hasSVE() ? 128 : 0);
     }
-    return 64;
+    llvm_unreachable("Unsupported register kind");
   }
 
   unsigned getMinVectorRegisterBitWidth() {

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
index a6c30ec4f571..ebc685207fe1 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -310,8 +310,17 @@ unsigned GCNTTIImpl::getNumberOfRegisters(unsigned RCID) const {
   return getHardwareNumberOfRegisters(false) / NumVGPRs;
 }
 
-unsigned GCNTTIImpl::getRegisterBitWidth(bool Vector) const {
-  return (Vector && ST->hasPackedFP32Ops()) ? 64 : 32;
+TypeSize
+GCNTTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+  switch (K) {
+  case TargetTransformInfo::RGK_Scalar:
+    return TypeSize::getFixed(32);
+  case TargetTransformInfo::RGK_FixedWidthVector:
+    return TypeSize::getFixed(ST->hasPackedFP32Ops() ? 64 : 32);
+  case TargetTransformInfo::RGK_ScalableVector:
+    return TypeSize::getScalable(0);
+  }
+  llvm_unreachable("Unsupported register kind");
 }
 
 unsigned GCNTTIImpl::getMinVectorRegisterBitWidth() const {
@@ -1224,8 +1233,9 @@ unsigned R600TTIImpl::getNumberOfRegisters(bool Vec) const {
   return getHardwareNumberOfRegisters(Vec);
 }
 
-unsigned R600TTIImpl::getRegisterBitWidth(bool Vector) const {
-  return 32;
+TypeSize
+R600TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+  return TypeSize::getFixed(32);
 }
 
 unsigned R600TTIImpl::getMinVectorRegisterBitWidth() const {

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
index 9068a2d70c21..02b1df1d44d7 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -121,7 +121,7 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
   unsigned getHardwareNumberOfRegisters(bool Vector) const;
   unsigned getNumberOfRegisters(bool Vector) const;
   unsigned getNumberOfRegisters(unsigned RCID) const;
-  unsigned getRegisterBitWidth(bool Vector) const;
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind Vector) const;
   unsigned getMinVectorRegisterBitWidth() const;
   unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const;
   unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize,
@@ -243,7 +243,7 @@ class R600TTIImpl final : public BasicTTIImplBase<R600TTIImpl> {
                              TTI::PeelingPreferences &PP);
   unsigned getHardwareNumberOfRegisters(bool Vec) const;
   unsigned getNumberOfRegisters(bool Vec) const;
-  unsigned getRegisterBitWidth(bool Vector) const;
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind Vector) const;
   unsigned getMinVectorRegisterBitWidth() const;
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const;
   bool isLegalToVectorizeMemChain(unsigned ChainSizeInBytes, Align Alignment,

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 39858a8282d4..79e63b292468 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -149,16 +149,20 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
     return 13;
   }
 
-  unsigned getRegisterBitWidth(bool Vector) const {
-    if (Vector) {
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+    switch (K) {
+    case TargetTransformInfo::RGK_Scalar:
+      return TypeSize::getFixed(32);
+    case TargetTransformInfo::RGK_FixedWidthVector:
       if (ST->hasNEON())
-        return 128;
+        return TypeSize::getFixed(128);
       if (ST->hasMVEIntegerOps())
-        return 128;
-      return 0;
+        return TypeSize::getFixed(128);
+      return TypeSize::getFixed(0);
+    case TargetTransformInfo::RGK_ScalableVector:
+      return TypeSize::getScalable(0);
     }
-
-    return 32;
+    llvm_unreachable("Unsupported register kind");
   }
 
   unsigned getMaxInterleaveFactor(unsigned VF) {

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 95b7c79c84f1..bfbba4994938 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -98,8 +98,18 @@ unsigned HexagonTTIImpl::getMaxInterleaveFactor(unsigned VF) {
   return useHVX() ? 2 : 1;
 }
 
-unsigned HexagonTTIImpl::getRegisterBitWidth(bool Vector) const {
-  return Vector ? getMinVectorRegisterBitWidth() : 32;
+TypeSize
+HexagonTTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+  switch (K) {
+  case TargetTransformInfo::RGK_Scalar:
+    return TypeSize::getFixed(32);
+  case TargetTransformInfo::RGK_FixedWidthVector:
+    return TypeSize::getFixed(getMinVectorRegisterBitWidth());
+  case TargetTransformInfo::RGK_ScalableVector:
+    return TypeSize::getScalable(0);
+  }
+
+  llvm_unreachable("Unsupported register kind");
 }
 
 unsigned HexagonTTIImpl::getMinVectorRegisterBitWidth() const {
@@ -162,7 +172,9 @@ unsigned HexagonTTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
     VectorType *VecTy = cast<VectorType>(Src);
     unsigned VecWidth = VecTy->getPrimitiveSizeInBits().getFixedSize();
     if (useHVX() && ST.isTypeForHVX(VecTy)) {
-      unsigned RegWidth = getRegisterBitWidth(true);
+      unsigned RegWidth =
+          getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+              .getFixedSize();
       assert(RegWidth && "Non-zero vector register width expected");
       // Cost of HVX loads.
       if (VecWidth % RegWidth == 0)

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index d68787d465fd..fd04d1df0dea 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -81,7 +81,7 @@ class HexagonTTIImpl : public BasicTTIImplBase<HexagonTTIImpl> {
 
   unsigned getNumberOfRegisters(bool vector) const;
   unsigned getMaxInterleaveFactor(unsigned VF);
-  unsigned getRegisterBitWidth(bool Vector) const;
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
   unsigned getMinVectorRegisterBitWidth() const;
   ElementCount getMinimumVF(unsigned ElemWidth, bool IsScalable) const;
 

diff  --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 6f071040dd9d..7d73d8f9cfd8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -71,7 +71,9 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
 
   // Only <2 x half> should be vectorized, so always return 32 for the vector
   // register size.
-  unsigned getRegisterBitWidth(bool Vector) const { return 32; }
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+    return TypeSize::getFixed(32);
+  }
   unsigned getMinVectorRegisterBitWidth() const { return 32; }
 
   // We don't want to prevent inlining because of target-cpu and -features

diff  --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
index c2b8c1937666..f8145fe4d3ef 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -871,16 +871,18 @@ const char* PPCTTIImpl::getRegisterClassName(unsigned ClassID) const {
   }
 }
 
-unsigned PPCTTIImpl::getRegisterBitWidth(bool Vector) const {
-  if (Vector) {
-    if (ST->hasAltivec()) return 128;
-    return 0;
+TypeSize
+PPCTTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+  switch (K) {
+  case TargetTransformInfo::RGK_Scalar:
+    return TypeSize::getFixed(ST->isPPC64() ? 64 : 32);
+  case TargetTransformInfo::RGK_FixedWidthVector:
+    return TypeSize::getFixed(ST->hasAltivec() ? 128 : 0);
+  case TargetTransformInfo::RGK_ScalableVector:
+    return TypeSize::getScalable(0);
   }
 
-  if (ST->isPPC64())
-    return 64;
-  return 32;
-
+  llvm_unreachable("Unsupported register kind");
 }
 
 unsigned PPCTTIImpl::getCacheLineSize() const {

diff  --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
index d1f6a9ccabfb..87ff0e32bdea 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
@@ -93,7 +93,7 @@ class PPCTTIImpl : public BasicTTIImplBase<PPCTTIImpl> {
   unsigned getNumberOfRegisters(unsigned ClassID) const;
   unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const;
   const char* getRegisterClassName(unsigned ClassID) const;
-  unsigned getRegisterBitWidth(bool Vector) const;
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
   unsigned getCacheLineSize() const override;
   unsigned getPrefetchDistance() const override;
   unsigned getMaxInterleaveFactor(unsigned VF);

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index bb8215b736ca..eef5f7953a94 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -52,13 +52,19 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
   bool supportsScalableVectors() const { return ST->hasStdExtV(); }
   Optional<unsigned> getMaxVScale() const;
 
-  unsigned getRegisterBitWidth(bool Vector) const {
-    if (Vector) {
-      if (ST->hasStdExtV())
-        return ST->getMinRVVVectorSizeInBits();
-      return 0;
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+    switch (K) {
+    case TargetTransformInfo::RGK_Scalar:
+      return TypeSize::getFixed(ST->getXLen());
+    case TargetTransformInfo::RGK_FixedWidthVector:
+      return TypeSize::getFixed(
+          ST->hasStdExtV() ? ST->getMinRVVVectorSizeInBits() : 0);
+    case TargetTransformInfo::RGK_ScalableVector:
+      return TypeSize::getScalable(
+          ST->hasStdExtV() ? ST->getMinRVVVectorSizeInBits() : 0);
     }
-    return ST->getXLen();
+
+    llvm_unreachable("Unsupported register kind");
   }
 
   bool isLegalElementTypeForRVV(Type *ScalarTy) {

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index 6ef5277bc5d4..1274dab7c025 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -323,12 +323,18 @@ unsigned SystemZTTIImpl::getNumberOfRegisters(unsigned ClassID) const {
   return 0;
 }
 
-unsigned SystemZTTIImpl::getRegisterBitWidth(bool Vector) const {
-  if (!Vector)
-    return 64;
-  if (ST->hasVector())
-    return 128;
-  return 0;
+TypeSize
+SystemZTTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+  switch (K) {
+  case TargetTransformInfo::RGK_Scalar:
+    return TypeSize::getFixed(64);
+  case TargetTransformInfo::RGK_FixedWidthVector:
+    return TypeSize::getFixed(ST->hasVector() ? 128 : 0);
+  case TargetTransformInfo::RGK_ScalableVector:
+    return TypeSize::getScalable(0);
+  }
+
+  llvm_unreachable("Unsupported register kind");
 }
 
 unsigned SystemZTTIImpl::getMinPrefetchStride(unsigned NumMemAccesses,

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index f3d068af02e4..0fe50c2efd16 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -62,7 +62,7 @@ class SystemZTTIImpl : public BasicTTIImplBase<SystemZTTIImpl> {
   /// @{
 
   unsigned getNumberOfRegisters(unsigned ClassID) const;
-  unsigned getRegisterBitWidth(bool Vector) const;
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
 
   unsigned getCacheLineSize() const override { return 256; }
   unsigned getPrefetchDistance() const override { return 4500; }

diff  --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 68af66597485..6730c43258d2 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -50,12 +50,18 @@ class VETTIImpl : public BasicTTIImplBase<VETTIImpl> {
     return 64;
   }
 
-  unsigned getRegisterBitWidth(bool Vector) const {
-    if (Vector) {
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
+    switch (K) {
+    case TargetTransformInfo::RGK_Scalar:
+      return TypeSize::getFixed(64);
+    case TargetTransformInfo::RGK_FixedWidthVector:
       // TODO report vregs once vector isel is stable.
-      return 0;
+      return TypeSize::getFixed(0);
+    case TargetTransformInfo::RGK_ScalableVector:
+      return TypeSize::getScalable(0);
     }
-    return 64;
+
+    llvm_unreachable("Unsupported register kind");
   }
 
   unsigned getMinVectorRegisterBitWidth() const {

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index a2111afbfcba..1357860c27c9 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -36,11 +36,18 @@ unsigned WebAssemblyTTIImpl::getNumberOfRegisters(unsigned ClassID) const {
   return Result;
 }
 
-unsigned WebAssemblyTTIImpl::getRegisterBitWidth(bool Vector) const {
-  if (Vector && getST()->hasSIMD128())
-    return 128;
+TypeSize WebAssemblyTTIImpl::getRegisterBitWidth(
+    TargetTransformInfo::RegisterKind K) const {
+  switch (K) {
+  case TargetTransformInfo::RGK_Scalar:
+    return TypeSize::getFixed(64);
+  case TargetTransformInfo::RGK_FixedWidthVector:
+    return TypeSize::getFixed(getST()->hasSIMD128() ? 128 : 64);
+  case TargetTransformInfo::RGK_ScalableVector:
+    return TypeSize::getScalable(0);
+  }
 
-  return 64;
+  llvm_unreachable("Unsupported register kind");
 }
 
 unsigned WebAssemblyTTIImpl::getArithmeticInstrCost(

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 3515a3e149d1..8b9719fe7ff2 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -57,7 +57,7 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
   /// @{
 
   unsigned getNumberOfRegisters(unsigned ClassID) const;
-  unsigned getRegisterBitWidth(bool Vector) const;
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
   unsigned getArithmeticInstrCost(
       unsigned Opcode, Type *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_SizeAndLatency,

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 3b8856f656f0..0a9582edd211 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -129,26 +129,30 @@ unsigned X86TTIImpl::getNumberOfRegisters(unsigned ClassID) const {
   return 8;
 }
 
-unsigned X86TTIImpl::getRegisterBitWidth(bool Vector) const {
+TypeSize
+X86TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   unsigned PreferVectorWidth = ST->getPreferVectorWidth();
-  if (Vector) {
+  switch (K) {
+  case TargetTransformInfo::RGK_Scalar:
+    return TypeSize::getFixed(ST->is64Bit() ? 64 : 32);
+  case TargetTransformInfo::RGK_FixedWidthVector:
     if (ST->hasAVX512() && PreferVectorWidth >= 512)
-      return 512;
+      return TypeSize::getFixed(512);
     if (ST->hasAVX() && PreferVectorWidth >= 256)
-      return 256;
+      return TypeSize::getFixed(256);
     if (ST->hasSSE1() && PreferVectorWidth >= 128)
-      return 128;
-    return 0;
+      return TypeSize::getFixed(128);
+    return TypeSize::getFixed(0);
+  case TargetTransformInfo::RGK_ScalableVector:
+    return TypeSize::getScalable(0);
   }
 
-  if (ST->is64Bit())
-    return 64;
-
-  return 32;
+  llvm_unreachable("Unsupported register kind");
 }
 
 unsigned X86TTIImpl::getLoadStoreVecRegBitWidth(unsigned) const {
-  return getRegisterBitWidth(true);
+  return getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+      .getFixedSize();
 }
 
 unsigned X86TTIImpl::getMaxInterleaveFactor(unsigned VF) {

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index fb3ab46c87ac..8b2f919347c7 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -115,7 +115,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   /// @{
 
   unsigned getNumberOfRegisters(unsigned ClassID) const;
-  unsigned getRegisterBitWidth(bool Vector) const;
+  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
   unsigned getLoadStoreVecRegBitWidth(unsigned AS) const;
   unsigned getMaxInterleaveFactor(unsigned VF);
   int getArithmeticInstrCost(

diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index e07eb6593200..544f67cd79e2 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -413,7 +413,9 @@ class LowerMatrixIntrinsics {
   /// \p VT * N.
   unsigned getNumOps(Type *ST, unsigned N) {
     return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
-                     double(TTI.getRegisterBitWidth(true)));
+                     double(TTI.getRegisterBitWidth(
+                                   TargetTransformInfo::RGK_FixedWidthVector)
+                                .getFixedSize()));
   }
 
   /// Return the set of vectors that a matrix value is lowered to.
@@ -1013,7 +1015,8 @@ class LowerMatrixIntrinsics {
                           const MatrixTy &B, bool AllowContraction,
                           IRBuilder<> &Builder, bool isTiled) {
     const unsigned VF = std::max<unsigned>(
-        TTI.getRegisterBitWidth(true) /
+        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+                .getFixedSize() /
             Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
         1U);
     unsigned R = Result.getNumRows();
@@ -1179,10 +1182,11 @@ class LowerMatrixIntrinsics {
     const unsigned M = LShape.NumColumns;
     auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
 
-    const unsigned VF =
-        std::max<unsigned>(TTI.getRegisterBitWidth(true) /
-                               EltType->getPrimitiveSizeInBits().getFixedSize(),
-                           1U);
+    const unsigned VF = std::max<unsigned>(
+        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+                .getFixedSize() /
+            EltType->getPrimitiveSizeInBits().getFixedSize(),
+        1U);
 
     // Cost model for tiling
     //

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c60880cc452a..34caa9acd4da 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5700,7 +5700,9 @@ LoopVectorizationCostModel::computeFeasibleMaxVF(unsigned ConstTripCount,
   MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI);
   unsigned SmallestType, WidestType;
   std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes();
-  unsigned WidestRegister = TTI.getRegisterBitWidth(true);
+  unsigned WidestRegister =
+      TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+          .getFixedSize();
 
   // Get the maximum safe dependence distance in bits computed by LAA.
   // It is computed by MaxVF * sizeOf(type) * 8, where type is taken from
@@ -7672,8 +7674,10 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
     // If the user doesn't provide a vectorization factor, determine a
     // reasonable one.
     if (UserVF.isZero()) {
-      VF = ElementCount::getFixed(
-          determineVPlanVF(TTI->getRegisterBitWidth(true /* Vector*/), CM));
+      VF = ElementCount::getFixed(determineVPlanVF(
+          TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+              .getFixedSize(),
+          CM));
       LLVM_DEBUG(dbgs() << "LV: VPlan computed VF " << VF << ".\n");
 
       // Make sure we have a VF > 1 for stress testing.

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 78d2ea0032db..d128312bc69f 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -574,7 +574,9 @@ class BoUpSLP {
     if (MaxVectorRegSizeOption.getNumOccurrences())
       MaxVecRegSize = MaxVectorRegSizeOption;
     else
-      MaxVecRegSize = TTI->getRegisterBitWidth(true);
+      MaxVecRegSize =
+          TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+              .getFixedSize();
 
     if (MinVectorRegSizeOption.getNumOccurrences())
       MinVecRegSize = MinVectorRegSizeOption;


        


More information about the llvm-commits mailing list