[llvm] 2e8d815 - [TTI] Support scalable offsets in getScalingFactorCost (#88113)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 10 03:22:15 PDT 2024


Author: Graham Hunter
Date: 2024-05-10T11:22:11+01:00
New Revision: 2e8d8155969f90b8f17634ce9a8e4541fb21dbab

URL: https://github.com/llvm/llvm-project/commit/2e8d8155969f90b8f17634ce9a8e4541fb21dbab
DIFF: https://github.com/llvm/llvm-project/commit/2e8d8155969f90b8f17634ce9a8e4541fb21dbab.diff

LOG: [TTI] Support scalable offsets in getScalingFactorCost (#88113)

Part of the work to support vscale-relative immediates in LSR.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h
    llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 1c76821fe5e4a..f0eb83c143e2c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -834,7 +834,7 @@ class TargetTransformInfo {
   /// If the AM is not supported, it returns a negative value.
   /// TODO: Handle pre/postinc as well.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
+                                       StackOffset BaseOffset, bool HasBaseReg,
                                        int64_t Scale,
                                        unsigned AddrSpace = 0) const;
 
@@ -1891,7 +1891,7 @@ class TargetTransformInfo::Concept {
   virtual bool hasVolatileVariant(Instruction *I, unsigned AddrSpace) = 0;
   virtual bool prefersVectorizedAddressing() = 0;
   virtual InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                               int64_t BaseOffset,
+                                               StackOffset BaseOffset,
                                                bool HasBaseReg, int64_t Scale,
                                                unsigned AddrSpace) = 0;
   virtual bool LSRWithInstrQueries() = 0;
@@ -2403,7 +2403,7 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.prefersVectorizedAddressing();
   }
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
+                                       StackOffset BaseOffset, bool HasBaseReg,
                                        int64_t Scale,
                                        unsigned AddrSpace) override {
     return Impl.getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 4d5cd963e0926..262ebdb3cbef9 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -32,6 +32,7 @@ class Function;
 /// Base class for use as a mix-in that aids implementing
 /// a TargetTransformInfo-compatible class.
 class TargetTransformInfoImplBase {
+
 protected:
   typedef TargetTransformInfo TTI;
 
@@ -326,12 +327,13 @@ class TargetTransformInfoImplBase {
   bool prefersVectorizedAddressing() const { return true; }
 
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
+                                       StackOffset BaseOffset, bool HasBaseReg,
                                        int64_t Scale,
                                        unsigned AddrSpace) const {
     // Guess that all legal addressing mode are free.
-    if (isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
-                              AddrSpace))
+    if (isLegalAddressingMode(Ty, BaseGV, BaseOffset.getFixed(), HasBaseReg,
+                              Scale, AddrSpace, /*I=*/nullptr,
+                              BaseOffset.getScalable()))
       return 0;
     return -1;
   }

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index bcb60c6562967..fa481886b268a 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -404,13 +404,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
+                                       StackOffset BaseOffset, bool HasBaseReg,
                                        int64_t Scale, unsigned AddrSpace) {
     TargetLoweringBase::AddrMode AM;
     AM.BaseGV = BaseGV;
-    AM.BaseOffs = BaseOffset;
+    AM.BaseOffs = BaseOffset.getFixed();
     AM.HasBaseReg = HasBaseReg;
     AM.Scale = Scale;
+    AM.ScalableOffset = BaseOffset.getScalable();
     if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
       return 0;
     return -1;

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 33c899fe88999..00443ace46f74 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -531,7 +531,7 @@ bool TargetTransformInfo::prefersVectorizedAddressing() const {
 }
 
 InstructionCost TargetTransformInfo::getScalingFactorCost(
-    Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg,
+    Type *Ty, GlobalValue *BaseGV, StackOffset BaseOffset, bool HasBaseReg,
     int64_t Scale, unsigned AddrSpace) const {
   InstructionCost Cost = TTIImpl->getScalingFactorCost(
       Ty, BaseGV, BaseOffset, HasBaseReg, Scale, AddrSpace);

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 4b410826f4bb1..f49c73dc79519 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4183,7 +4183,7 @@ bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) {
 
 InstructionCost
 AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                     int64_t BaseOffset, bool HasBaseReg,
+                                     StackOffset BaseOffset, bool HasBaseReg,
                                      int64_t Scale, unsigned AddrSpace) const {
   // Scaling factors are not free at all.
   // Operands                     | Rt Latency
@@ -4194,9 +4194,10 @@ AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
   // Rt, [Xn, Wm, <extend> #imm]  |
   TargetLoweringBase::AddrMode AM;
   AM.BaseGV = BaseGV;
-  AM.BaseOffs = BaseOffset;
+  AM.BaseOffs = BaseOffset.getFixed();
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
+  AM.ScalableOffset = BaseOffset.getScalable();
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
     // Scale represents reg2 * scale, thus account for 1 if
     // it is not equal to 0 or 1.

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 678c132e6a80a..2f44aaa3e26ab 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -407,7 +407,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   /// If the AM is supported, the return value must be >= 0.
   /// If the AM is not supported, it returns a negative value.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
+                                       StackOffset BaseOffset, bool HasBaseReg,
                                        int64_t Scale, unsigned AddrSpace) const;
   /// @}
 

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index ee87f7f0e555e..7db2e8ee7e6f9 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -2571,14 +2571,15 @@ bool ARMTTIImpl::preferPredicatedReductionSelect(
 }
 
 InstructionCost ARMTTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                                 int64_t BaseOffset,
+                                                 StackOffset BaseOffset,
                                                  bool HasBaseReg, int64_t Scale,
                                                  unsigned AddrSpace) const {
   TargetLoweringBase::AddrMode AM;
   AM.BaseGV = BaseGV;
-  AM.BaseOffs = BaseOffset;
+  AM.BaseOffs = BaseOffset.getFixed();
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
+  AM.ScalableOffset = BaseOffset.getScalable();
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace)) {
     if (ST->hasFPAO())
       return AM.Scale < 0 ? 1 : 0; // positive offsets execute faster

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 58eab45b9641f..8c4b92b856888 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -303,7 +303,7 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
   /// If the AM is supported, the return value must be >= 0.
   /// If the AM is not supported, the return value must be negative.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
+                                       StackOffset BaseOffset, bool HasBaseReg,
                                        int64_t Scale, unsigned AddrSpace) const;
 
   bool maybeLoweredToCall(Instruction &I);

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 6b7cddc6d72e4..d43480d0a0125 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6741,7 +6741,7 @@ InstructionCost X86TTIImpl::getInterleavedMemoryOpCost(
 }
 
 InstructionCost X86TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                                 int64_t BaseOffset,
+                                                 StackOffset BaseOffset,
                                                  bool HasBaseReg, int64_t Scale,
                                                  unsigned AddrSpace) const {
   // Scaling factors are not free at all.
@@ -6764,9 +6764,10 @@ InstructionCost X86TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
   // vmovaps %ymm1, (%r8) can use port 2, 3, or 7.
   TargetLoweringBase::AddrMode AM;
   AM.BaseGV = BaseGV;
-  AM.BaseOffs = BaseOffset;
+  AM.BaseOffs = BaseOffset.getFixed();
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
+  AM.ScalableOffset = BaseOffset.getScalable();
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
     // Scale represents reg2 * scale, thus account for 1
     // as soon as we use a second register.

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index b501930745739..d720cc136b8ae 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -253,7 +253,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   /// If the AM is supported, the return value must be >= 0.
   /// If the AM is not supported, it returns a negative value.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
+                                       StackOffset BaseOffset, bool HasBaseReg,
                                        int64_t Scale, unsigned AddrSpace) const;
 
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,

diff  --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index ec42e2d6e193a..eb1904ccaff35 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1817,10 +1817,12 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
   case LSRUse::Address: {
     // Check the scaling factor cost with both the min and max offsets.
     InstructionCost ScaleCostMinOffset = TTI.getScalingFactorCost(
-        LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MinOffset, F.HasBaseReg,
+        LU.AccessTy.MemTy, F.BaseGV,
+        StackOffset::getFixed(F.BaseOffset + LU.MinOffset), F.HasBaseReg,
         F.Scale, LU.AccessTy.AddrSpace);
     InstructionCost ScaleCostMaxOffset = TTI.getScalingFactorCost(
-        LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MaxOffset, F.HasBaseReg,
+        LU.AccessTy.MemTy, F.BaseGV,
+        StackOffset::getFixed(F.BaseOffset + LU.MaxOffset), F.HasBaseReg,
         F.Scale, LU.AccessTy.AddrSpace);
 
     assert(ScaleCostMinOffset.isValid() && ScaleCostMaxOffset.isValid() &&


        


More information about the llvm-commits mailing list