[llvm] [LSR] Recognize vscale-relative immediates (PR #88124)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 9 06:37:55 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Graham Hunter (huntergr-arm)

<details>
<summary>Changes</summary>

Final part of the vscale-aware LSR work; see the RFC at https://discourse.llvm.org/t/rfc-vscale-aware-loopstrengthreduce/77131

It's a bit messy right now; I mainly want to know whether there are any objections to the current work before I finish it up.

---

The patch is 59.52 KiB and is truncated to 20.00 KiB below; the full version is at https://github.com/llvm/llvm-project/pull/88124.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp (+372-160) 
- (added) llvm/test/Transforms/LoopStrengthReduce/AArch64/vscale-fixups.ll (+147) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index ec42e2d6e193a6..b5d0113bafe023 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -197,6 +197,14 @@ static cl::opt<bool> AllowDropSolutionIfLessProfitable(
     "lsr-drop-solution", cl::Hidden, cl::init(false),
     cl::desc("Attempt to drop solution if it is less profitable"));
 
+// Lets ExtractImmediate peel `C * vscale` terms off SCEVs and report them as
+// scalable (vscale-relative) immediates.
+static cl::opt<bool> EnableVScaleImmediates(
+    "lsr-enable-vscale-immediates",
+    cl::desc("Enable analysis of vscale-relative immediates in LSR"),
+    cl::init(true), cl::Hidden);
+
+// When set, isAlwaysFoldable drops the scaled register (Scale = 0) for
+// non-ICmpZero accesses to scalable types that carry both a base register
+// and a non-zero offset.
+static cl::opt<bool> DropScaledForVScale(
+    "lsr-drop-scaled-reg-for-vscale",
+    cl::desc("Avoid using scaled registers with vscale-relative addressing"),
+    cl::init(true), cl::Hidden);
+
 STATISTIC(NumTermFold,
           "Number of terminating condition fold recognized and performed");
 
@@ -247,6 +255,68 @@ class RegSortData {
   void dump() const;
 };
 
+// A constant offset from an address that is either fixed or a multiple of
+// vscale (scalable). It describes immediates when matching per-target
+// addressing modes.
+class Immediate : public details::FixedOrScalableQuantity<Immediate, int64_t> {
+  // Construction is private; use get(), getFixed() or getScalable().
+  constexpr Immediate(ScalarTy MinVal, bool Scalable)
+      : FixedOrScalableQuantity(MinVal, Scalable) {}
+
+  // Conversion back from the underlying FixedOrScalableQuantity.
+  constexpr Immediate(const FixedOrScalableQuantity<Immediate, int64_t> &V)
+      : FixedOrScalableQuantity(V) {}
+
+public:
+  constexpr Immediate() : FixedOrScalableQuantity() {}
+
+  /// Returns an immediate of \p MinVal that is scalable iff \p Scalable.
+  static constexpr Immediate get(ScalarTy MinVal, bool Scalable) {
+    return Immediate(MinVal, Scalable);
+  }
+  /// Returns the fixed immediate \p MinVal.
+  static constexpr Immediate getFixed(ScalarTy MinVal) {
+    return get(MinVal, /*Scalable=*/false);
+  }
+  /// Returns the scalable immediate \p MinVal * vscale.
+  static constexpr Immediate getScalable(ScalarTy MinVal) {
+    return get(MinVal, /*Scalable=*/true);
+  }
+
+  // Predicates on the stored (known-minimum) quantity; they do not look at
+  // whether the value is scalable.
+  constexpr bool isLessThanZero() const { return getKnownMinValue() < 0; }
+  constexpr bool isGreaterThanZero() const { return getKnownMinValue() > 0; }
+  constexpr bool isMin() const {
+    return getKnownMinValue() == std::numeric_limits<ScalarTy>::min();
+  }
+  constexpr bool isMax() const {
+    return getKnownMinValue() == std::numeric_limits<ScalarTy>::max();
+  }
+};
+
+// Strict weak ordering for using Immediate as a std::map key. Every fixed
+// value sorts before every scalable one, and values of the same kind compare
+// by their known-minimum quantity. This is not a true ordering for all values
+// of vscale; it only keeps vscale-relative keys grouped together so they are
+// compared against each other instead of being interleaved with fixed keys
+// (which could hide opportunities).
+struct KeyOrderTargetImmediate {
+  bool operator()(const Immediate &LHS, const Immediate &RHS) const {
+    const bool LScalable = LHS.isScalable();
+    const bool RScalable = RHS.isScalable();
+    // Mixed kinds: the fixed side is the smaller one.
+    if (LScalable != RScalable)
+      return RScalable;
+    return LHS.getKnownMinValue() < RHS.getKnownMinValue();
+  }
+};
+
+// Lexicographic ordering for (size_t, Immediate) map keys: the size_t part is
+// compared first and ties are broken with KeyOrderTargetImmediate. The first
+// component is hard-coded to size_t because there doesn't seem to be a type
+// trait (is_orderable / is_lessthan_comparable or similar) to constrain a
+// generic version with.
+struct KeyOrderSizeTAndImmediate {
+  bool operator()(const std::pair<size_t, Immediate> &LHS,
+                  const std::pair<size_t, Immediate> &RHS) const {
+    if (LHS.first == RHS.first)
+      return KeyOrderTargetImmediate()(LHS.second, RHS.second);
+    return LHS.first < RHS.first;
+  }
+};
 } // end anonymous namespace
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -357,7 +427,7 @@ struct Formula {
   GlobalValue *BaseGV = nullptr;
 
   /// Base offset for complex addressing.
-  int64_t BaseOffset = 0;
+  Immediate BaseOffset;
 
   /// Whether any complex addressing has a base register.
   bool HasBaseReg = false;
@@ -388,7 +458,7 @@ struct Formula {
   /// An additional constant offset which added near the use. This requires a
   /// temporary register, but the offset itself can live in an add immediate
   /// field rather than a register.
-  int64_t UnfoldedOffset = 0;
+  Immediate UnfoldedOffset;
 
   Formula() = default;
 
@@ -628,7 +698,7 @@ void Formula::print(raw_ostream &OS) const {
     if (!First) OS << " + "; else First = false;
     BaseGV->printAsOperand(OS, /*PrintType=*/false);
   }
-  if (BaseOffset != 0) {
+  if (BaseOffset.isNonZero()) {
     if (!First) OS << " + "; else First = false;
     OS << BaseOffset;
   }
@@ -652,7 +722,7 @@ void Formula::print(raw_ostream &OS) const {
       OS << "<unknown>";
     OS << ')';
   }
-  if (UnfoldedOffset != 0) {
+  if (UnfoldedOffset.isNonZero()) {
     if (!First) OS << " + ";
     OS << "imm(" << UnfoldedOffset << ')';
   }
@@ -798,28 +868,34 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
 
 /// If S involves the addition of a constant integer value, return that integer
 /// value, and mutate S to point to a new SCEV with that value excluded.
-static int64_t ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
+static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
     if (C->getAPInt().getSignificantBits() <= 64) {
       S = SE.getConstant(C->getType(), 0);
-      return C->getValue()->getSExtValue();
+      return Immediate::getFixed(C->getValue()->getSExtValue());
     }
   } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
     SmallVector<const SCEV *, 8> NewOps(Add->operands());
-    int64_t Result = ExtractImmediate(NewOps.front(), SE);
-    if (Result != 0)
+    Immediate Result = ExtractImmediate(NewOps.front(), SE);
+    if (Result.isNonZero())
       S = SE.getAddExpr(NewOps);
     return Result;
   } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
     SmallVector<const SCEV *, 8> NewOps(AR->operands());
-    int64_t Result = ExtractImmediate(NewOps.front(), SE);
-    if (Result != 0)
+    Immediate Result = ExtractImmediate(NewOps.front(), SE);
+    if (Result.isNonZero())
       S = SE.getAddRecExpr(NewOps, AR->getLoop(),
                            // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
                            SCEV::FlagAnyWrap);
     return Result;
-  }
-  return 0;
+  } else if (EnableVScaleImmediates)
+    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S))
+      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
+        if (isa<SCEVVScale>(M->getOperand(1))) {
+          S = SE.getConstant(M->getType(), 0);
+          return Immediate::getScalable(C->getValue()->getSExtValue());
+        }
+  return Immediate();
 }
 
 /// If S involves the addition of a GlobalValue address, return that symbol, and
@@ -1134,7 +1210,7 @@ struct LSRFixup {
   /// A constant offset to be added to the LSRUse expression.  This allows
   /// multiple fixups to share the same LSRUse with different offsets, for
   /// example in an unrolled loop.
-  int64_t Offset = 0;
+  Immediate Offset;
 
   LSRFixup() = default;
 
@@ -1197,8 +1273,10 @@ class LSRUse {
   SmallVector<LSRFixup, 8> Fixups;
 
   /// Keep track of the min and max offsets of the fixups.
-  int64_t MinOffset = std::numeric_limits<int64_t>::max();
-  int64_t MaxOffset = std::numeric_limits<int64_t>::min();
+  Immediate MinOffset =
+      Immediate::getFixed(std::numeric_limits<int64_t>::max());
+  Immediate MaxOffset =
+      Immediate::getFixed(std::numeric_limits<int64_t>::min());
 
   /// This records whether all of the fixups using this LSRUse are outside of
   /// the loop, in which case some special-case heuristics may be used.
@@ -1234,9 +1312,9 @@ class LSRUse {
 
+  /// Append f to Fixups, moving MaxOffset / MinOffset to f.Offset when it is
+  /// known to be greater / less than the current bound.
   void pushFixup(LSRFixup &f) {
     Fixups.push_back(f);
+    // NOTE(review): MinOffset and MaxOffset start out as *fixed* sentinels
+    // (INT64_MAX / INT64_MIN). If isKnownLT/isKnownGT follow the usual
+    // FixedOrScalableQuantity rules, a scalable f.Offset can replace the
+    // MaxOffset sentinel but is never "known less than" the fixed MinOffset
+    // sentinel, so a use with only scalable fixups would keep a fixed
+    // MinOffset next to a scalable MaxOffset. Confirm this is intended.
-    if (f.Offset > MaxOffset)
+    if (Immediate::isKnownGT(f.Offset, MaxOffset))
       MaxOffset = f.Offset;
-    if (f.Offset < MinOffset)
+    if (Immediate::isKnownLT(f.Offset, MinOffset))
       MinOffset = f.Offset;
   }
 
@@ -1254,7 +1332,7 @@ class LSRUse {
 
+// Forward declaration. BaseOffset is an Immediate, so callers may pass either
+// a fixed or a vscale-relative offset.
 static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
-                                 GlobalValue *BaseGV, int64_t BaseOffset,
+                                 GlobalValue *BaseGV, Immediate BaseOffset,
                                  bool HasBaseReg, int64_t Scale,
                                  Instruction *Fixup = nullptr);
 
@@ -1310,7 +1388,7 @@ void Cost::RateRegister(const Formula &F, const SCEV *Reg,
       // addressing.
       if (AMK == TTI::AMK_PreIndexed) {
         if (auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)))
-          if (Step->getAPInt() == F.BaseOffset)
+          if (Step->getAPInt() == F.BaseOffset.getFixedValue())
             LoopCost = 0;
       } else if (AMK == TTI::AMK_PostIndexed) {
         const SCEV *LoopStep = AR->getStepRecurrence(*SE);
@@ -1401,24 +1479,29 @@ void Cost::RateFormula(const Formula &F,
     // allows to fold 2 registers.
     C.NumBaseAdds +=
         NumBaseParts - (1 + (F.Scale && isAMCompletelyFolded(*TTI, LU, F)));
-  C.NumBaseAdds += (F.UnfoldedOffset != 0);
+  C.NumBaseAdds += (F.UnfoldedOffset.isNonZero());
 
   // Accumulate non-free scaling amounts.
   C.ScaleCost += *getScalingFactorCost(*TTI, LU, F, *L).getValue();
 
   // Tally up the non-zero immediates.
   for (const LSRFixup &Fixup : LU.Fixups) {
-    int64_t O = Fixup.Offset;
-    int64_t Offset = (uint64_t)O + F.BaseOffset;
+    // FIXME: We probably want to noticeably increase the cost if the
+    // two offsets differ in scalability?
+    bool Scalable = Fixup.Offset.isScalable() || F.BaseOffset.isScalable();
+    int64_t O = Fixup.Offset.getKnownMinValue();
+    Immediate Offset = Immediate::get(
+        (uint64_t)(O) + F.BaseOffset.getKnownMinValue(), Scalable);
     if (F.BaseGV)
       C.ImmCost += 64; // Handle symbolic values conservatively.
                      // TODO: This should probably be the pointer size.
-    else if (Offset != 0)
-      C.ImmCost += APInt(64, Offset, true).getSignificantBits();
+    else if (Offset.isNonZero())
+      C.ImmCost +=
+          APInt(64, Offset.getKnownMinValue(), true).getSignificantBits();
 
     // Check with target if this offset with this instruction is
     // specifically not supported.
-    if (LU.Kind == LSRUse::Address && Offset != 0 &&
+    if (LU.Kind == LSRUse::Address && Offset.isNonZero() &&
         !isAMCompletelyFolded(*TTI, LSRUse::Address, LU.AccessTy, F.BaseGV,
                               Offset, F.HasBaseReg, F.Scale, Fixup.UserInst))
       C.NumBaseAdds++;
@@ -1546,7 +1629,7 @@ void LSRFixup::print(raw_ostream &OS) const {
     PIL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
   }
 
-  if (Offset != 0)
+  if (Offset.isNonZero())
     OS << ", Offset=" << Offset;
 }
 
@@ -1673,14 +1756,19 @@ LLVM_DUMP_METHOD void LSRUse::dump() const {
 
 static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
-                                 GlobalValue *BaseGV, int64_t BaseOffset,
+                                 GlobalValue *BaseGV, Immediate BaseOffset,
                                  bool HasBaseReg, int64_t Scale,
-                                 Instruction *Fixup/*= nullptr*/) {
+                                 Instruction *Fixup /*= nullptr*/) {
   switch (Kind) {
-  case LSRUse::Address:
-    return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset,
-                                     HasBaseReg, Scale, AccessTy.AddrSpace, Fixup);
-
+  case LSRUse::Address: {
+    int64_t FixedOffset =
+        BaseOffset.isScalable() ? 0 : BaseOffset.getFixedValue();
+    int64_t ScalableOffset =
+        BaseOffset.isScalable() ? BaseOffset.getKnownMinValue() : 0;
+    return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, FixedOffset,
+                                     HasBaseReg, Scale, AccessTy.AddrSpace,
+                                     Fixup, ScalableOffset);
+  }
   case LSRUse::ICmpZero:
     // There's not even a target hook for querying whether it would be legal to
     // fold a GV into an ICmp.
@@ -1688,7 +1776,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
       return false;
 
     // ICmp only has two operands; don't allow more than two non-trivial parts.
-    if (Scale != 0 && HasBaseReg && BaseOffset != 0)
+    if (Scale != 0 && HasBaseReg && BaseOffset.isNonZero())
       return false;
 
     // ICmp only supports no scale or a -1 scale, as we can "fold" a -1 scale by
@@ -1698,7 +1786,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
 
     // If we have low-level target information, ask the target if it can fold an
     // integer immediate on an icmp.
-    if (BaseOffset != 0) {
+    if (BaseOffset.isNonZero()) {
       // We have one of:
       // ICmpZero     BaseReg + BaseOffset => ICmp BaseReg, -BaseOffset
       // ICmpZero -1*ScaleReg + BaseOffset => ICmp ScaleReg, BaseOffset
@@ -1706,8 +1794,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
       if (Scale == 0)
         // The cast does the right thing with
         // std::numeric_limits<int64_t>::min().
-        BaseOffset = -(uint64_t)BaseOffset;
-      return TTI.isLegalICmpImmediate(BaseOffset);
+        BaseOffset = BaseOffset.getFixed((uint64_t)BaseOffset.getFixedValue());
+      return TTI.isLegalICmpImmediate(BaseOffset.getFixedValue());
     }
 
     // ICmpZero BaseReg + -1*ScaleReg => ICmp BaseReg, ScaleReg
@@ -1715,30 +1803,36 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
 
   case LSRUse::Basic:
     // Only handle single-register values.
-    return !BaseGV && Scale == 0 && BaseOffset == 0;
+    return !BaseGV && Scale == 0 && BaseOffset.isZero();
 
   case LSRUse::Special:
     // Special case Basic to handle -1 scales.
-    return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0;
+    return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset.isZero();
   }
 
   llvm_unreachable("Invalid LSRUse Kind!");
 }
 
 static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
-                                 int64_t MinOffset, int64_t MaxOffset,
+                                 Immediate MinOffset, Immediate MaxOffset,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
-                                 GlobalValue *BaseGV, int64_t BaseOffset,
+                                 GlobalValue *BaseGV, Immediate BaseOffset,
                                  bool HasBaseReg, int64_t Scale) {
+  if (BaseOffset.isNonZero() &&
+      (BaseOffset.isScalable() != MinOffset.isScalable() ||
+       BaseOffset.isScalable() != MaxOffset.isScalable()))
+    return false;
+  // Check for overflow.
   // Check for overflow.
-  if (((int64_t)((uint64_t)BaseOffset + MinOffset) > BaseOffset) !=
-      (MinOffset > 0))
+  int64_t Base = BaseOffset.getKnownMinValue();
+  int64_t Min = MinOffset.getKnownMinValue();
+  int64_t Max = MaxOffset.getKnownMinValue();
+  if (((int64_t)((uint64_t)Base + Min) > Base) != (Min > 0))
     return false;
-  MinOffset = (uint64_t)BaseOffset + MinOffset;
-  if (((int64_t)((uint64_t)BaseOffset + MaxOffset) > BaseOffset) !=
-      (MaxOffset > 0))
+  MinOffset = Immediate::get((uint64_t)Base + Min, MinOffset.isScalable());
+  if (((int64_t)((uint64_t)Base + Max) > Base) != (Max > 0))
     return false;
-  MaxOffset = (uint64_t)BaseOffset + MaxOffset;
+  MaxOffset = Immediate::get((uint64_t)Base + Max, MaxOffset.isScalable());
 
   return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, MinOffset,
                               HasBaseReg, Scale) &&
@@ -1747,7 +1841,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
 }
 
 static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
-                                 int64_t MinOffset, int64_t MaxOffset,
+                                 Immediate MinOffset, Immediate MaxOffset,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
                                  const Formula &F, const Loop &L) {
   // For the purpose of isAMCompletelyFolded either having a canonical formula
@@ -1763,10 +1857,10 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
 }
 
 /// Test whether we know how to expand the current formula.
-static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset,
-                       int64_t MaxOffset, LSRUse::KindType Kind,
+static bool isLegalUse(const TargetTransformInfo &TTI, Immediate MinOffset,
+                       Immediate MaxOffset, LSRUse::KindType Kind,
                        MemAccessTy AccessTy, GlobalValue *BaseGV,
-                       int64_t BaseOffset, bool HasBaseReg, int64_t Scale) {
+                       Immediate BaseOffset, bool HasBaseReg, int64_t Scale) {
   // We know how to expand completely foldable formulae.
   return isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy, BaseGV,
                               BaseOffset, HasBaseReg, Scale) ||
@@ -1777,13 +1871,21 @@ static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset,
                                BaseGV, BaseOffset, true, 0));
 }
 
-static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset,
-                       int64_t MaxOffset, LSRUse::KindType Kind,
+/// Convenience overload of isLegalUse that reads BaseGV, BaseOffset,
+/// HasBaseReg and Scale from the formula F.
+static bool isLegalUse(const TargetTransformInfo &TTI, Immediate MinOffset,
+                       Immediate MaxOffset, LSRUse::KindType Kind,
                        MemAccessTy AccessTy, const Formula &F) {
   return isLegalUse(TTI, MinOffset, MaxOffset, Kind, AccessTy, F.BaseGV,
                     F.BaseOffset, F.HasBaseReg, F.Scale);
 }
 
+/// Asks the target whether Offset can be used directly as an add immediate.
+/// A scalable offset is queried by its vscale coefficient through
+/// isLegalAddScalableImmediate; a fixed one through isLegalAddImmediate.
+static bool isLegalAddImmediate(const TargetTransformInfo &TTI,
+                                Immediate Offset) {
+  return Offset.isScalable()
+             ? TTI.isLegalAddScalableImmediate(Offset.getKnownMinValue())
+             : TTI.isLegalAddImmediate(Offset.getFixedValue());
+}
+
 static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  const LSRUse &LU, const Formula &F) {
   // Target may want to look at the user instructions.
@@ -1817,11 +1919,13 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
   case LSRUse::Address: {
     // Check the scaling factor cost with both the min and max offsets.
     InstructionCost ScaleCostMinOffset = TTI.getScalingFactorCost(
-        LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MinOffset, F.HasBaseReg,
-        F.Scale, LU.AccessTy.AddrSpace);
+        LU.AccessTy.MemTy, F.BaseGV,
+        F.BaseOffset.getFixedValue() + LU.MinOffset.getFixedValue(),
+        F.HasBaseReg, F.Scale, LU.AccessTy.AddrSpace);
     InstructionCost ScaleCostMaxOffset = TTI.getScalingFactorCost(
-        LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MaxOffset, F.HasBaseReg,
-        F.Scale, LU.AccessTy.AddrSpace);
+        LU.AccessTy.MemTy, F.BaseGV,
+        F.BaseOffset.getFixedValue() + LU.MaxOffset.getFixedValue(),
+        F.HasBaseReg, F.Scale, LU.AccessTy.AddrSpace);
 
     assert(ScaleCostMinOffset.isValid() && ScaleCostMaxOffset.isValid() &&
            "Legal addressing mode has an illegal cost!");
@@ -1840,10 +1944,11 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
 
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
                              LSRUse::KindType Kind, MemAccessTy AccessTy,
-                             GlobalValue *BaseGV, int64_t BaseOffset,
+                             GlobalValue *BaseGV, Immediate BaseOffset,
                              bool HasBaseReg) {
   // Fast-path: zero is always foldable.
-  if (BaseOffset == 0 && !BaseGV) return true;
+  if (BaseOffset.isZero() && !BaseGV)
+    return true;
 
   // Conservatively, create an address with an immediate and a
   // base and a scale.
@@ -1856,13 +1961,22 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
     HasBaseReg = true;
   }
 
+  // FIXME: Try with + without a scale? Maybe based on TTI?
+  // I think basereg + scaledreg + immediateoffset isn't a good 'conservative'
+  // default for many architectures, not just AArch64 SVE. More investigation
+  // needed later to determine if this should be used more widely than just
+  // on scalable types.
+  if (HasBaseReg && BaseOffset.isNonZero() && Kind != LSRUse::ICmpZero &&
+      AccessTy.MemTy && AccessTy.MemTy->isScalableTy() && DropScaledForVScale)
+    Scale = 0;
+
   return isAMCompletelyFolded(TTI, Kind...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/88124


More information about the llvm-commits mailing list