[llvm] [TTI] Support scalable offsets in getScalingFactorCost (PR #88113)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 9 05:16:05 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-arm

@llvm/pr-subscribers-llvm-analysis

Author: Graham Hunter (huntergr-arm)

<details>
<summary>Changes</summary>

Part of the work to support vscale-relative immediates in LoopStrengthReduce (LSR).

No tests added yet, but I'd like feedback on the approach.

---
Full diff: https://github.com/llvm/llvm-project/pull/88113.diff


10 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+7-6) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+3-3) 
- (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+3-1) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+2-2) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+4-4) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2-1) 
- (modified) llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp (+3-1) 
- (modified) llvm/lib/Target/ARM/ARMTargetTransformInfo.h (+2-1) 
- (modified) llvm/lib/Target/X86/X86TargetTransformInfo.cpp (+2-1) 
- (modified) llvm/lib/Target/X86/X86TargetTransformInfo.h (+2-1) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index fa9392b86c15b9..4c6b8e312786cc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -835,8 +835,8 @@ class TargetTransformInfo {
   /// TODO: Handle pre/postinc as well.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale,
-                                       unsigned AddrSpace = 0) const;
+                                       int64_t Scale, unsigned AddrSpace = 0,
+                                       int64_t ScalableOffset = 0) const;
 
   /// Return true if the loop strength reduce pass should make
   /// Instruction* based TTI queries to isLegalAddressingMode(). This is
@@ -1894,7 +1894,8 @@ class TargetTransformInfo::Concept {
   virtual InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                                int64_t BaseOffset,
                                                bool HasBaseReg, int64_t Scale,
-                                               unsigned AddrSpace) = 0;
+                                               unsigned AddrSpace,
+                                               int64_t ScalableOffset) = 0;
   virtual bool LSRWithInstrQueries() = 0;
   virtual bool isTruncateFree(Type *Ty1, Type *Ty2) = 0;
   virtual bool isProfitableToHoist(Instruction *I) = 0;
@@ -2406,10 +2407,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale,
-                                       unsigned AddrSpace) override {
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) override {
     return Impl.getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
-                                     AddrSpace);
+                                     AddrSpace, ScalableOffset);
   }
   bool LSRWithInstrQueries() override { return Impl.LSRWithInstrQueries(); }
   bool isTruncateFree(Type *Ty1, Type *Ty2) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 63c2ef8912b29c..72c7b805abbb67 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -327,11 +327,11 @@ class TargetTransformInfoImplBase {
 
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale,
-                                       unsigned AddrSpace) const {
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) const {
     // Guess that all legal addressing mode are free.
     if (isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
-                              AddrSpace))
+                              AddrSpace, /*I=*/nullptr, ScalableOffset))
       return 0;
     return -1;
   }
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 42d8f74fd427fb..7f42e239d85d96 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -406,12 +406,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace) {
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) {
     TargetLoweringBase::AddrMode AM;
     AM.BaseGV = BaseGV;
     AM.BaseOffs = BaseOffset;
     AM.HasBaseReg = HasBaseReg;
     AM.Scale = Scale;
+    AM.ScalableOffset = ScalableOffset;
     if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
       return 0;
     return -1;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 5f933b4587843c..d00ab62bad9fad 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -532,9 +532,9 @@ bool TargetTransformInfo::prefersVectorizedAddressing() const {
 
 InstructionCost TargetTransformInfo::getScalingFactorCost(
     Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg,
-    int64_t Scale, unsigned AddrSpace) const {
+    int64_t Scale, unsigned AddrSpace, int64_t ScalableOffset) const {
   InstructionCost Cost = TTIImpl->getScalingFactorCost(
-      Ty, BaseGV, BaseOffset, HasBaseReg, Scale, AddrSpace);
+      Ty, BaseGV, BaseOffset, HasBaseReg, Scale, AddrSpace, ScalableOffset);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ee7137b92445bb..2b75f0ea2d4d6f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4118,10 +4118,9 @@ bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) {
   return NumInsns >= SVETailFoldInsnThreshold;
 }
 
-InstructionCost
-AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                     int64_t BaseOffset, bool HasBaseReg,
-                                     int64_t Scale, unsigned AddrSpace) const {
+InstructionCost AArch64TTIImpl::getScalingFactorCost(
+    Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg,
+    int64_t Scale, unsigned AddrSpace, int64_t ScalableOffset) const {
   // Scaling factors are not free at all.
   // Operands                     | Rt Latency
   // -------------------------------------------
@@ -4134,6 +4133,7 @@ AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
   AM.BaseOffs = BaseOffset;
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
+  AM.ScalableOffset = ScalableOffset;
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
     // Scale represents reg2 * scale, thus account for 1 if
     // it is not equal to 0 or 1.
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index de39dea2be43e1..0f7315446c70d4 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -407,7 +407,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   /// If the AM is not supported, it returns a negative value.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace) const;
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) const;
   /// @}
 
   bool enableSelectOptimize() { return ST->enableSelectOptimize(); }
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 3be894ad3bef2c..73e47fbea23057 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -2572,12 +2572,14 @@ bool ARMTTIImpl::preferPredicatedReductionSelect(
 InstructionCost ARMTTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                                  int64_t BaseOffset,
                                                  bool HasBaseReg, int64_t Scale,
-                                                 unsigned AddrSpace) const {
+                                                 unsigned AddrSpace,
+                                                 int64_t ScalableOffset) const {
   TargetLoweringBase::AddrMode AM;
   AM.BaseGV = BaseGV;
   AM.BaseOffs = BaseOffset;
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
+  AM.ScalableOffset = ScalableOffset;
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace)) {
     if (ST->hasFPAO())
       return AM.Scale < 0 ? 1 : 0; // positive offsets execute faster
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index bb4b321b530091..10e4b2977a563a 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -303,7 +303,8 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
   /// If the AM is not supported, the return value must be negative.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace) const;
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) const;
 
   bool maybeLoweredToCall(Instruction &I);
   bool isLoweredToCall(const Function *F);
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 5d1810b5bc2c6f..0cfa1da2ce7d78 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6670,7 +6670,8 @@ InstructionCost X86TTIImpl::getInterleavedMemoryOpCost(
 InstructionCost X86TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                                  int64_t BaseOffset,
                                                  bool HasBaseReg, int64_t Scale,
-                                                 unsigned AddrSpace) const {
+                                                 unsigned AddrSpace,
+                                                 int64_t ScalableOffset) const {
   // Scaling factors are not free at all.
   // An indexed folded instruction, i.e., inst (reg1, reg2, scale),
   // will take 2 allocations in the out of order engine instead of 1
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 985b00438ce878..060b2b98b341da 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -253,7 +253,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   /// If the AM is not supported, it returns a negative value.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace) const;
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) const;
 
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2);

``````````

</details>


https://github.com/llvm/llvm-project/pull/88113


More information about the llvm-commits mailing list