[llvm] [AArch64] Support scalable offsets with isLegalAddressingMode (PR #83255)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 13 06:59:47 PDT 2024


https://github.com/huntergr-arm updated https://github.com/llvm/llvm-project/pull/83255

>From 04354d540b01fbfea9d5b78aa5fcad790f890a12 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Fri, 23 Feb 2024 14:02:13 +0000
Subject: [PATCH 1/3] [TTI][TLI][NFC] Add 'ScalableOffset' to
 isLegalAddressingMode

Adds a new parameter to the TTI version of the function, along with
a matching field in the struct for TLI.

This extra int64_t parameter holds a scalable byte offset, kept
separate from the fixed BaseOffset: it must be multiplied by 'vscale'
to get the real value at runtime. It defaults to zero, and the default
TargetLowering implementation rejects any nonzero scalable offset.
---
 .../llvm/Analysis/TargetTransformInfo.h       | 11 ++--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  3 +-
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |  6 ++-
 llvm/include/llvm/CodeGen/TargetLowering.h    |  1 +
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  5 +-
 llvm/lib/CodeGen/TargetLoweringBase.cpp       |  4 ++
 .../Target/AArch64/AArch64ISelLowering.cpp    |  4 +-
 .../Target/AArch64/AddressingModes.cpp        | 50 ++++++++++++++++++-
 8 files changed, 71 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 4eab357f1b33b6..fd9bbd6441f981 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -710,8 +710,8 @@ class TargetTransformInfo {
   /// TODO: Handle pre/postinc as well.
   bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
                              bool HasBaseReg, int64_t Scale,
-                             unsigned AddrSpace = 0,
-                             Instruction *I = nullptr) const;
+                             unsigned AddrSpace = 0, Instruction *I = nullptr,
+                             int64_t ScalableOffset = 0) const;
 
   /// Return true if LSR cost of C1 is lower than C2.
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
@@ -1839,7 +1839,8 @@ class TargetTransformInfo::Concept {
   virtual bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
                                      int64_t BaseOffset, bool HasBaseReg,
                                      int64_t Scale, unsigned AddrSpace,
-                                     Instruction *I) = 0;
+                                     Instruction *I,
+                                     int64_t ScalableOffset) = 0;
   virtual bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                              const TargetTransformInfo::LSRCost &C2) = 0;
   virtual bool isNumRegsMajorCostOfLSR() = 0;
@@ -2300,9 +2301,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
   bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
                              bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
-                             Instruction *I) override {
+                             Instruction *I, int64_t ScalableOffset) override {
     return Impl.isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
-                                      AddrSpace, I);
+                                      AddrSpace, I, ScalableOffset);
   }
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7f661bb4a1df20..07eeceeeaa22a8 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -220,7 +220,8 @@ class TargetTransformInfoImplBase {
 
   bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
                              bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
-                             Instruction *I = nullptr) const {
+                             Instruction *I = nullptr,
+                             int64_t ScalableOffset = 0) const {
     // Guess that only reg and reg+reg addressing is allowed. This heuristic is
     // taken from the implementation of LSR.
     return !BaseGV && BaseOffset == 0 && (Scale == 0 || Scale == 1);
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 61f6564e8cd79b..721900038ddd57 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -333,13 +333,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
-                             bool HasBaseReg, int64_t Scale,
-                             unsigned AddrSpace, Instruction *I = nullptr) {
+                             bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
+                             Instruction *I = nullptr,
+                             int64_t ScalableOffset = 0) {
     TargetLoweringBase::AddrMode AM;
     AM.BaseGV = BaseGV;
     AM.BaseOffs = BaseOffset;
     AM.HasBaseReg = HasBaseReg;
     AM.Scale = Scale;
+    AM.ScalableOffset = ScalableOffset;
     return getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace, I);
   }
 
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 2f164a460db843..629949d2e76b0e 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -2733,6 +2733,7 @@ class TargetLoweringBase {
     int64_t      BaseOffs = 0;
     bool         HasBaseReg = false;
     int64_t      Scale = 0;
+    int64_t ScalableOffset = 0;
     AddrMode() = default;
   };
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 15311be4dba277..4b113e6d3798cd 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -403,9 +403,10 @@ bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
                                                 int64_t BaseOffset,
                                                 bool HasBaseReg, int64_t Scale,
                                                 unsigned AddrSpace,
-                                                Instruction *I) const {
+                                                Instruction *I,
+                                                int64_t ScalableOffset) const {
   return TTIImpl->isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg,
-                                        Scale, AddrSpace, I);
+                                        Scale, AddrSpace, I, ScalableOffset);
 }
 
 bool TargetTransformInfo::isLSRCostLess(const LSRCost &C1,
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 8ac55ee6a5d0c1..9990556f89ed8b 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -2011,6 +2011,10 @@ bool TargetLoweringBase::isLegalAddressingMode(const DataLayout &DL,
   // The default implementation of this implements a conservative RISCy, r+r and
   // r+i addr mode.
 
+  // Scalable offsets not supported
+  if (AM.ScalableOffset)
+    return false;
+
   // Allows a sign-extended 16-bit immediate field.
   if (AM.BaseOffs <= -(1LL << 16) || AM.BaseOffs >= (1LL << 16)-1)
     return false;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5b7a36d2eba76f..23adf595ef3264 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16673,11 +16673,11 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
     if (isa<ScalableVectorType>(Ty)) {
       uint64_t VecElemNumBytes =
           DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8;
-      return AM.HasBaseReg && !AM.BaseOffs &&
+      return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset &&
              (AM.Scale == 0 || (uint64_t)AM.Scale == VecElemNumBytes);
     }
 
-    return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale;
+    return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset && !AM.Scale;
   }
 
   // check reg + imm case:
diff --git a/llvm/unittests/Target/AArch64/AddressingModes.cpp b/llvm/unittests/Target/AArch64/AddressingModes.cpp
index 284ea7ae9233ed..30dbbf2fe8a757 100644
--- a/llvm/unittests/Target/AArch64/AddressingModes.cpp
+++ b/llvm/unittests/Target/AArch64/AddressingModes.cpp
@@ -13,11 +13,13 @@ using namespace llvm;
 namespace {
 
 struct AddrMode : public TargetLowering::AddrMode {
-  constexpr AddrMode(GlobalValue *GV, int64_t Offs, bool HasBase, int64_t S) {
+  constexpr AddrMode(GlobalValue *GV, int64_t Offs, bool HasBase, int64_t S,
+                     int64_t SOffs = 0) {
     BaseGV = GV;
     BaseOffs = Offs;
     HasBaseReg = HasBase;
     Scale = S;
+    ScalableOffset = SOffs;
   }
 };
 struct TestCase {
@@ -153,6 +155,45 @@ const std::initializer_list<TestCase> Tests = {
     {{nullptr, 4096 + 1, true, 0}, 8, false},
 
 };
+
+struct SVETestCase {
+  AddrMode AM;
+  unsigned TypeBits;
+  unsigned NumElts;
+  bool Result;
+};
+
+const std::initializer_list<SVETestCase> SVETests = {
+    // {BaseGV, BaseOffs, HasBaseReg, Scale, SOffs}, EltBits, Count, Result
+    // Test immediate range -- [-8,7] vector's worth.
+    // <vscale x 16 x i8>, increment by one vector
+    {{nullptr, 0, true, 0, 16}, 8, 16, false},
+    // <vscale x 4 x i32>, increment by eight vectors
+    {{nullptr, 0, true, 0, 128}, 32, 4, false},
+    // <vscale x 8 x i16>, increment by seven vectors
+    {{nullptr, 0, true, 0, 112}, 16, 8, false},
+    // <vscale x 2 x i64>, decrement by eight vectors
+    {{nullptr, 0, true, 0, -128}, 64, 2, false},
+    // <vscale x 16 x i8>, decrement by nine vectors
+    {{nullptr, 0, true, 0, -144}, 8, 16, false},
+
+    // Half the size of a vector register, but allowable with extending
+    // loads and truncating stores
+    // <vscale x 8 x i8>, increment by three vectors
+    {{nullptr, 0, true, 0, 24}, 8, 8, false},
+
+    // Test invalid types or offsets
+    // <vscale x 5 x i32>, increment by one vector (base size > 16B)
+    {{nullptr, 0, true, 0, 20}, 32, 5, false},
+    // <vscale x 8 x i16>, increment by half a vector
+    {{nullptr, 0, true, 0, 8}, 16, 8, false},
+    // <vscale x 3 x i8>, increment by 3 vectors (non-power-of-two)
+    {{nullptr, 0, true, 0, 9}, 8, 3, false},
+
+    // Scalable and fixed offsets
+    // <vscale x 16 x i8>, increment by 32 then decrement by vscale x 16
+    {{nullptr, 32, true, 0, -16}, 8, 16, false},
+};
 } // namespace
 
 TEST(AddressingModes, AddressingModes) {
@@ -179,4 +220,11 @@ TEST(AddressingModes, AddressingModes) {
     Type *Typ = Type::getIntNTy(Ctx, Test.TypeBits);
     ASSERT_EQ(TLI->isLegalAddressingMode(DL, Test.AM, Typ, 0), Test.Result);
   }
+
+  for (const auto &SVETest : SVETests) {
+    Type *Ty = VectorType::get(Type::getIntNTy(Ctx, SVETest.TypeBits),
+                               ElementCount::getScalable(SVETest.NumElts));
+    ASSERT_EQ(TLI->isLegalAddressingMode(DL, SVETest.AM, Ty, 0),
+              SVETest.Result);
+  }
 }

>From 68ea630837ec62697e3455cc2b102e17f38b5ae2 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Fri, 23 Feb 2024 15:49:03 +0000
Subject: [PATCH 2/3] [AArch64] Support scalable offsets with
 isLegalAddressingMode

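Given a base register and a scalable offset (multiplied by vscale),
return true if the offset is a whole number of vectors and that
number lies in the range [-8, 7]. The check applies only when the
vector type's minimum size in memory is a power of two no larger than
16 bytes; e.g. `[X0, #1, mul vl]`.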
---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp   | 14 ++++++++++++++
 llvm/unittests/Target/AArch64/AddressingModes.cpp |  8 ++++----
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 23adf595ef3264..2bfaed146e9b60 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16671,6 +16671,16 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
 
   if (Ty->isScalableTy()) {
     if (isa<ScalableVectorType>(Ty)) {
+      // See if we have a foldable vscale-based offset, for vector types which
+      // are either legal or smaller than the minimum; more work will be
+      // required if we need to consider addressing for types which need
+      // legalization by splitting.
+      uint64_t VecNumBytes = DL.getTypeSizeInBits(Ty).getKnownMinValue() / 8;
+      if (AM.HasBaseReg && !AM.BaseOffs && AM.ScalableOffset && !AM.Scale &&
+          (AM.ScalableOffset % VecNumBytes == 0) && VecNumBytes <= 16 &&
+          isPowerOf2_64(VecNumBytes))
+        return isInt<4>(AM.ScalableOffset / (int64_t)VecNumBytes);
+
       uint64_t VecElemNumBytes =
           DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8;
       return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset &&
@@ -16680,6 +16690,10 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
     return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset && !AM.Scale;
   }
 
+  // No scalable offsets allowed for non-scalable types.
+  if (AM.ScalableOffset)
+    return false;
+
   // check reg + imm case:
   // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
   uint64_t NumBytes = 0;
diff --git a/llvm/unittests/Target/AArch64/AddressingModes.cpp b/llvm/unittests/Target/AArch64/AddressingModes.cpp
index 30dbbf2fe8a757..0af18d886791a1 100644
--- a/llvm/unittests/Target/AArch64/AddressingModes.cpp
+++ b/llvm/unittests/Target/AArch64/AddressingModes.cpp
@@ -167,20 +167,20 @@ const std::initializer_list<SVETestCase> SVETests = {
     // {BaseGV, BaseOffs, HasBaseReg, Scale, SOffs}, EltBits, Count, Result
     // Test immediate range -- [-8,7] vector's worth.
     // <vscale x 16 x i8>, increment by one vector
-    {{nullptr, 0, true, 0, 16}, 8, 16, false},
+    {{nullptr, 0, true, 0, 16}, 8, 16, true},
     // <vscale x 4 x i32>, increment by eight vectors
     {{nullptr, 0, true, 0, 128}, 32, 4, false},
     // <vscale x 8 x i16>, increment by seven vectors
-    {{nullptr, 0, true, 0, 112}, 16, 8, false},
+    {{nullptr, 0, true, 0, 112}, 16, 8, true},
     // <vscale x 2 x i64>, decrement by eight vectors
-    {{nullptr, 0, true, 0, -128}, 64, 2, false},
+    {{nullptr, 0, true, 0, -128}, 64, 2, true},
     // <vscale x 16 x i8>, decrement by nine vectors
     {{nullptr, 0, true, 0, -144}, 8, 16, false},
 
     // Half the size of a vector register, but allowable with extending
     // loads and truncating stores
     // <vscale x 8 x i8>, increment by three vectors
-    {{nullptr, 0, true, 0, 24}, 8, 8, false},
+    {{nullptr, 0, true, 0, 24}, 8, 8, true},
 
     // Test invalid types or offsets
     // <vscale x 5 x i32>, increment by one vector (base size > 16B)

>From ee298b67080ece2acbf24c323b70b2385a1f6990 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Wed, 13 Mar 2024 13:31:12 +0000
Subject: [PATCH 3/3] Add some doxygen comments

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h | 4 ++++
 llvm/include/llvm/CodeGen/TargetLowering.h       | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index fd9bbd6441f981..10e12238251933 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -707,6 +707,10 @@ class TargetTransformInfo {
   /// The type may be VoidTy, in which case only return true if the addressing
   /// mode is legal for a load/store of any legal type.
   /// If target returns true in LSRWithInstrQueries(), I may be valid.
+  /// \param ScalableOffset represents a quantity of bytes multiplied by vscale,
+  /// an invariant value known only at runtime. Most targets should not accept
+  /// a scalable offset.
+  ///
   /// TODO: Handle pre/postinc as well.
   bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
                              bool HasBaseReg, int64_t Scale,
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 629949d2e76b0e..4753d8e8a51257 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -2722,12 +2722,13 @@ class TargetLoweringBase {
   }
 
   /// This represents an addressing mode of:
-  ///    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
+  ///    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg + ScalableOffset*vscale
   /// If BaseGV is null,  there is no BaseGV.
   /// If BaseOffs is zero, there is no base offset.
   /// If HasBaseReg is false, there is no base register.
   /// If Scale is zero, there is no ScaleReg.  Scale of 1 indicates a reg with
   /// no scale.
+  /// If ScalableOffset is zero, there is no scalable offset.
   struct AddrMode {
     GlobalValue *BaseGV = nullptr;
     int64_t      BaseOffs = 0;



More information about the llvm-commits mailing list