[llvm] [TTI] Support scalable offsets in getScalingFactorCost (PR #88113)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 23 07:35:55 PDT 2024


https://github.com/huntergr-arm updated https://github.com/llvm/llvm-project/pull/88113

>From 49ee8d73abc77c2fc646e6c6d11a2caa38fe80a9 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Wed, 3 Apr 2024 15:09:36 +0100
Subject: [PATCH 1/2] [TTI] Support scalable offsets in getScalingFactorCost

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h    | 13 +++++++------
 .../include/llvm/Analysis/TargetTransformInfoImpl.h |  6 +++---
 llvm/include/llvm/CodeGen/BasicTTIImpl.h            |  4 +++-
 llvm/lib/Analysis/TargetTransformInfo.cpp           |  4 ++--
 .../Target/AArch64/AArch64TargetTransformInfo.cpp   |  8 ++++----
 .../lib/Target/AArch64/AArch64TargetTransformInfo.h |  3 ++-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp      |  4 +++-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h        |  3 ++-
 llvm/lib/Target/X86/X86TargetTransformInfo.cpp      |  3 ++-
 llvm/lib/Target/X86/X86TargetTransformInfo.h        |  3 ++-
 10 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58c69ac939763a..2dd22c84036909 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -835,8 +835,8 @@ class TargetTransformInfo {
   /// TODO: Handle pre/postinc as well.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale,
-                                       unsigned AddrSpace = 0) const;
+                                       int64_t Scale, unsigned AddrSpace = 0,
+                                       int64_t ScalableOffset = 0) const;
 
   /// Return true if the loop strength reduce pass should make
   /// Instruction* based TTI queries to isLegalAddressingMode(). This is
@@ -1893,7 +1893,8 @@ class TargetTransformInfo::Concept {
   virtual InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                                int64_t BaseOffset,
                                                bool HasBaseReg, int64_t Scale,
-                                               unsigned AddrSpace) = 0;
+                                               unsigned AddrSpace,
+                                               int64_t ScalableOffset) = 0;
   virtual bool LSRWithInstrQueries() = 0;
   virtual bool isTruncateFree(Type *Ty1, Type *Ty2) = 0;
   virtual bool isProfitableToHoist(Instruction *I) = 0;
@@ -2404,10 +2405,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale,
-                                       unsigned AddrSpace) override {
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) override {
     return Impl.getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
-                                     AddrSpace);
+                                     AddrSpace, ScalableOffset);
   }
   bool LSRWithInstrQueries() override { return Impl.LSRWithInstrQueries(); }
   bool isTruncateFree(Type *Ty1, Type *Ty2) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 4d5cd963e09267..43186aa426e42e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -327,11 +327,11 @@ class TargetTransformInfoImplBase {
 
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale,
-                                       unsigned AddrSpace) const {
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) const {
     // Guess that all legal addressing mode are free.
     if (isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
-                              AddrSpace))
+                              AddrSpace, /*I=*/nullptr, ScalableOffset))
       return 0;
     return -1;
   }
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 06a19c75cf873a..df5ab35c5ebc56 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -404,12 +404,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace) {
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) {
     TargetLoweringBase::AddrMode AM;
     AM.BaseGV = BaseGV;
     AM.BaseOffs = BaseOffset;
     AM.HasBaseReg = HasBaseReg;
     AM.Scale = Scale;
+    AM.ScalableOffset = ScalableOffset;
     if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
       return 0;
     return -1;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 33c899fe889990..4c1be47a2b70ed 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -532,9 +532,9 @@ bool TargetTransformInfo::prefersVectorizedAddressing() const {
 
 InstructionCost TargetTransformInfo::getScalingFactorCost(
     Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg,
-    int64_t Scale, unsigned AddrSpace) const {
+    int64_t Scale, unsigned AddrSpace, int64_t ScalableOffset) const {
   InstructionCost Cost = TTIImpl->getScalingFactorCost(
-      Ty, BaseGV, BaseOffset, HasBaseReg, Scale, AddrSpace);
+      Ty, BaseGV, BaseOffset, HasBaseReg, Scale, AddrSpace, ScalableOffset);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 700242b88346cb..de3eba0fa7de87 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4160,10 +4160,9 @@ bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) {
   return NumInsns >= SVETailFoldInsnThreshold;
 }
 
-InstructionCost
-AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                     int64_t BaseOffset, bool HasBaseReg,
-                                     int64_t Scale, unsigned AddrSpace) const {
+InstructionCost AArch64TTIImpl::getScalingFactorCost(
+    Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg,
+    int64_t Scale, unsigned AddrSpace, int64_t ScalableOffset) const {
   // Scaling factors are not free at all.
   // Operands                     | Rt Latency
   // -------------------------------------------
@@ -4176,6 +4175,7 @@ AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
   AM.BaseOffs = BaseOffset;
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
+  AM.ScalableOffset = ScalableOffset;
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
     // Scale represents reg2 * scale, thus account for 1 if
     // it is not equal to 0 or 1.
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index dba384481f6a34..9ce354f861f6cc 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -408,7 +408,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   /// If the AM is not supported, it returns a negative value.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace) const;
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) const;
   /// @}
 
   bool enableSelectOptimize() { return ST->enableSelectOptimize(); }
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index ee87f7f0e555ef..24c302d4440802 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -2573,12 +2573,14 @@ bool ARMTTIImpl::preferPredicatedReductionSelect(
 InstructionCost ARMTTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                                  int64_t BaseOffset,
                                                  bool HasBaseReg, int64_t Scale,
-                                                 unsigned AddrSpace) const {
+                                                 unsigned AddrSpace,
+                                                 int64_t ScalableOffset) const {
   TargetLoweringBase::AddrMode AM;
   AM.BaseGV = BaseGV;
   AM.BaseOffs = BaseOffset;
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
+  AM.ScalableOffset = ScalableOffset;
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace)) {
     if (ST->hasFPAO())
       return AM.Scale < 0 ? 1 : 0; // positive offsets execute faster
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 04b32194f806f6..70fc21c6dcf348 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -304,7 +304,8 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
   /// If the AM is not supported, the return value must be negative.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace) const;
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) const;
 
   bool maybeLoweredToCall(Instruction &I);
   bool isLoweredToCall(const Function *F);
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index aac355713f90f6..af56d714fa0a71 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6710,7 +6710,8 @@ InstructionCost X86TTIImpl::getInterleavedMemoryOpCost(
 InstructionCost X86TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                                  int64_t BaseOffset,
                                                  bool HasBaseReg, int64_t Scale,
-                                                 unsigned AddrSpace) const {
+                                                 unsigned AddrSpace,
+                                                 int64_t ScalableOffset) const {
   // Scaling factors are not free at all.
   // An indexed folded instruction, i.e., inst (reg1, reg2, scale),
   // will take 2 allocations in the out of order engine instead of 1
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 8ef9b4f86ffd7c..013d2520d5a749 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -254,7 +254,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   /// If the AM is not supported, it returns a negative value.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace) const;
+                                       int64_t Scale, unsigned AddrSpace,
+                                       int64_t ScalableOffset) const;
 
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2);

>From 2a61c4dc4f1b63fc296f0030382ba73b83cb7cc7 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Tue, 23 Apr 2024 15:30:14 +0100
Subject: [PATCH 2/2] Switch to StackOffset

---
 .../llvm/Analysis/TargetTransformInfo.h       | 19 +++++++++----------
 .../llvm/Analysis/TargetTransformInfoImpl.h   | 13 ++++++++-----
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |  9 ++++-----
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  6 +++---
 .../AArch64/AArch64TargetTransformInfo.cpp    | 11 ++++++-----
 .../AArch64/AArch64TargetTransformInfo.h      |  5 ++---
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp |  9 ++++-----
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |  5 ++---
 .../lib/Target/X86/X86TargetTransformInfo.cpp |  8 ++++----
 llvm/lib/Target/X86/X86TargetTransformInfo.h  |  5 ++---
 .../Transforms/Scalar/LoopStrengthReduce.cpp  |  6 ++++--
 11 files changed, 48 insertions(+), 48 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 2dd22c84036909..3329bcddd5bf39 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -834,9 +834,9 @@ class TargetTransformInfo {
   /// If the AM is not supported, it returns a negative value.
   /// TODO: Handle pre/postinc as well.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace = 0,
-                                       int64_t ScalableOffset = 0) const;
+                                       StackOffset BaseOffset, bool HasBaseReg,
+                                       int64_t Scale,
+                                       unsigned AddrSpace = 0) const;
 
   /// Return true if the loop strength reduce pass should make
   /// Instruction* based TTI queries to isLegalAddressingMode(). This is
@@ -1891,10 +1891,9 @@ class TargetTransformInfo::Concept {
   virtual bool hasVolatileVariant(Instruction *I, unsigned AddrSpace) = 0;
   virtual bool prefersVectorizedAddressing() = 0;
   virtual InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                               int64_t BaseOffset,
+                                               StackOffset BaseOffset,
                                                bool HasBaseReg, int64_t Scale,
-                                               unsigned AddrSpace,
-                                               int64_t ScalableOffset) = 0;
+                                               unsigned AddrSpace) = 0;
   virtual bool LSRWithInstrQueries() = 0;
   virtual bool isTruncateFree(Type *Ty1, Type *Ty2) = 0;
   virtual bool isProfitableToHoist(Instruction *I) = 0;
@@ -2404,11 +2403,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.prefersVectorizedAddressing();
   }
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace,
-                                       int64_t ScalableOffset) override {
+                                       StackOffset BaseOffset, bool HasBaseReg,
+                                       int64_t Scale,
+                                       unsigned AddrSpace) override {
     return Impl.getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
-                                     AddrSpace, ScalableOffset);
+                                     AddrSpace);
   }
   bool LSRWithInstrQueries() override { return Impl.LSRWithInstrQueries(); }
   bool isTruncateFree(Type *Ty1, Type *Ty2) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 43186aa426e42e..ff857578ecc2e2 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -32,6 +32,8 @@ class Function;
 /// Base class for use as a mix-in that aids implementing
 /// a TargetTransformInfo-compatible class.
 class TargetTransformInfoImplBase {
+  friend class TargetTransformInfo;
+
 protected:
   typedef TargetTransformInfo TTI;
 
@@ -326,12 +328,13 @@ class TargetTransformInfoImplBase {
   bool prefersVectorizedAddressing() const { return true; }
 
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace,
-                                       int64_t ScalableOffset) const {
+                                       StackOffset BaseOffset, bool HasBaseReg,
+                                       int64_t Scale,
+                                       unsigned AddrSpace) const {
     // Guess that all legal addressing mode are free.
-    if (isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
-                              AddrSpace, /*I=*/nullptr, ScalableOffset))
+    if (isLegalAddressingMode(Ty, BaseGV, BaseOffset.getFixed(), HasBaseReg,
+                              Scale, AddrSpace, /*I=*/nullptr,
+                              BaseOffset.getScalable()))
       return 0;
     return -1;
   }
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index df5ab35c5ebc56..3a96ce694b3e0f 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -403,15 +403,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace,
-                                       int64_t ScalableOffset) {
+                                       StackOffset BaseOffset, bool HasBaseReg,
+                                       int64_t Scale, unsigned AddrSpace) {
     TargetLoweringBase::AddrMode AM;
     AM.BaseGV = BaseGV;
-    AM.BaseOffs = BaseOffset;
+    AM.BaseOffs = BaseOffset.getFixed();
     AM.HasBaseReg = HasBaseReg;
     AM.Scale = Scale;
-    AM.ScalableOffset = ScalableOffset;
+    AM.ScalableOffset = BaseOffset.getScalable();
     if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
       return 0;
     return -1;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 4c1be47a2b70ed..00443ace46f745 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -531,10 +531,10 @@ bool TargetTransformInfo::prefersVectorizedAddressing() const {
 }
 
 InstructionCost TargetTransformInfo::getScalingFactorCost(
-    Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg,
-    int64_t Scale, unsigned AddrSpace, int64_t ScalableOffset) const {
+    Type *Ty, GlobalValue *BaseGV, StackOffset BaseOffset, bool HasBaseReg,
+    int64_t Scale, unsigned AddrSpace) const {
   InstructionCost Cost = TTIImpl->getScalingFactorCost(
-      Ty, BaseGV, BaseOffset, HasBaseReg, Scale, AddrSpace, ScalableOffset);
+      Ty, BaseGV, BaseOffset, HasBaseReg, Scale, AddrSpace);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index de3eba0fa7de87..7790dd91b77875 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4160,9 +4160,10 @@ bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) {
   return NumInsns >= SVETailFoldInsnThreshold;
 }
 
-InstructionCost AArch64TTIImpl::getScalingFactorCost(
-    Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg,
-    int64_t Scale, unsigned AddrSpace, int64_t ScalableOffset) const {
+InstructionCost
+AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+                                     StackOffset BaseOffset, bool HasBaseReg,
+                                     int64_t Scale, unsigned AddrSpace) const {
   // Scaling factors are not free at all.
   // Operands                     | Rt Latency
   // -------------------------------------------
@@ -4172,10 +4173,10 @@ InstructionCost AArch64TTIImpl::getScalingFactorCost(
   // Rt, [Xn, Wm, <extend> #imm]  |
   TargetLoweringBase::AddrMode AM;
   AM.BaseGV = BaseGV;
-  AM.BaseOffs = BaseOffset;
+  AM.BaseOffs = BaseOffset.getFixed();
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
-  AM.ScalableOffset = ScalableOffset;
+  AM.ScalableOffset = BaseOffset.getScalable();
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
     // Scale represents reg2 * scale, thus account for 1 if
     // it is not equal to 0 or 1.
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 9ce354f861f6cc..42b09b26a2c5b9 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -407,9 +407,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   /// If the AM is supported, the return value must be >= 0.
   /// If the AM is not supported, it returns a negative value.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace,
-                                       int64_t ScalableOffset) const;
+                                       StackOffset BaseOffset, bool HasBaseReg,
+                                       int64_t Scale, unsigned AddrSpace) const;
   /// @}
 
   bool enableSelectOptimize() { return ST->enableSelectOptimize(); }
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 24c302d4440802..5868438fcc3676 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -2571,16 +2571,15 @@ bool ARMTTIImpl::preferPredicatedReductionSelect(
 }
 
 InstructionCost ARMTTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                                 int64_t BaseOffset,
+                                                 StackOffset BaseOffset,
                                                  bool HasBaseReg, int64_t Scale,
-                                                 unsigned AddrSpace,
-                                                 int64_t ScalableOffset) const {
+                                                 unsigned AddrSpace) const {
   TargetLoweringBase::AddrMode AM;
   AM.BaseGV = BaseGV;
-  AM.BaseOffs = BaseOffset;
+  AM.BaseOffs = BaseOffset.getFixed();
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
-  AM.ScalableOffset = ScalableOffset;
+  assert(!BaseOffset.getScalable() && "Scalable offsets unsupported");
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace)) {
     if (ST->hasFPAO())
       return AM.Scale < 0 ? 1 : 0; // positive offsets execute faster
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 70fc21c6dcf348..9ea1404cc2fdc5 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -303,9 +303,8 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
   /// If the AM is supported, the return value must be >= 0.
   /// If the AM is not supported, the return value must be negative.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace,
-                                       int64_t ScalableOffset) const;
+                                       StackOffset BaseOffset, bool HasBaseReg,
+                                       int64_t Scale, unsigned AddrSpace) const;
 
   bool maybeLoweredToCall(Instruction &I);
   bool isLoweredToCall(const Function *F);
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index af56d714fa0a71..13d8a3474c5522 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6708,10 +6708,9 @@ InstructionCost X86TTIImpl::getInterleavedMemoryOpCost(
 }
 
 InstructionCost X86TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                                 int64_t BaseOffset,
+                                                 StackOffset BaseOffset,
                                                  bool HasBaseReg, int64_t Scale,
-                                                 unsigned AddrSpace,
-                                                 int64_t ScalableOffset) const {
+                                                 unsigned AddrSpace) const {
   // Scaling factors are not free at all.
   // An indexed folded instruction, i.e., inst (reg1, reg2, scale),
   // will take 2 allocations in the out of order engine instead of 1
@@ -6732,9 +6731,10 @@ InstructionCost X86TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
   // vmovaps %ymm1, (%r8) can use port 2, 3, or 7.
   TargetLoweringBase::AddrMode AM;
   AM.BaseGV = BaseGV;
-  AM.BaseOffs = BaseOffset;
+  AM.BaseOffs = BaseOffset.getFixed();
   AM.HasBaseReg = HasBaseReg;
   AM.Scale = Scale;
+  assert(!BaseOffset.getScalable() && "Scalable offsets unsupported");
   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
     // Scale represents reg2 * scale, thus account for 1
     // as soon as we use a second register.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 013d2520d5a749..0a56112c6e771d 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -253,9 +253,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   /// If the AM is supported, the return value must be >= 0.
   /// If the AM is not supported, it returns a negative value.
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
-                                       int64_t BaseOffset, bool HasBaseReg,
-                                       int64_t Scale, unsigned AddrSpace,
-                                       int64_t ScalableOffset) const;
+                                       StackOffset BaseOffset, bool HasBaseReg,
+                                       int64_t Scale, unsigned AddrSpace) const;
 
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2);
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index ec42e2d6e193a6..eb1904ccaff352 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1817,10 +1817,12 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
   case LSRUse::Address: {
     // Check the scaling factor cost with both the min and max offsets.
     InstructionCost ScaleCostMinOffset = TTI.getScalingFactorCost(
-        LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MinOffset, F.HasBaseReg,
+        LU.AccessTy.MemTy, F.BaseGV,
+        StackOffset::getFixed(F.BaseOffset + LU.MinOffset), F.HasBaseReg,
         F.Scale, LU.AccessTy.AddrSpace);
     InstructionCost ScaleCostMaxOffset = TTI.getScalingFactorCost(
-        LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MaxOffset, F.HasBaseReg,
+        LU.AccessTy.MemTy, F.BaseGV,
+        StackOffset::getFixed(F.BaseOffset + LU.MaxOffset), F.HasBaseReg,
         F.Scale, LU.AccessTy.AddrSpace);
 
     assert(ScaleCostMinOffset.isValid() && ScaleCostMaxOffset.isValid() &&



More information about the llvm-commits mailing list