[llvm] Limit Number of Values used for SCEV Calculation (PR #140565)

Manish Kausik H via llvm-commits llvm-commits at lists.llvm.org
Mon May 19 09:01:36 PDT 2025


https://github.com/Nirhar created https://github.com/llvm/llvm-project/pull/140565

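In large IR, a single SCEV computation can involve a very large number of Values. Because SCEV is computed recursively, computing the SCEVs of all of those Values can overflow the stack. This patch addresses this by limiting the SCEV computation once the number of Values involved crosses a threshold. From then on, any Value whose SCEV is not already cached is modeled as a `SCEVUnknown` instead of being analyzed further. The limit is controlled by the hidden options `-limit-num-values-in-scev-calculation` (on by default) and `-max-values-in-scev-calculation` (default 100).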
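To make the budgeting idea concrete, here is a minimal C++ sketch of memoized recursive analysis with a per-query budget over a toy expression DAG. `Node`, `ToyAnalysis`, `get`, `getWithLimits` and `makeChain` are hypothetical stand-ins (loosely mirroring `Value`, `ScalarEvolution`, `getSCEV` and `getSCEVWithLimits`), and `std::nullopt` plays the role of `SCEVUnknown`. It illustrates the mechanism only and is not the real SCEV API. Because one unit of budget is charged per newly analyzed node, the recursion depth in this toy model can never exceed the limit.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <vector>

// Hypothetical toy IR (not LLVM's Value): a node's value is its Constant
// plus the values of its Operands.
struct Node {
  long Constant = 0;
  std::vector<const Node *> Operands;
};

// std::nullopt plays the role of SCEVUnknown ("could not analyze").
using Result = std::optional<long>;

// Hypothetical stand-in for ScalarEvolution with a per-query node budget.
class ToyAnalysis {
public:
  // Limit == std::nullopt means "no limit", like turning the option off.
  explicit ToyAnalysis(std::optional<unsigned> Limit) : Limit(Limit) {}

  // Top-level query (cf. getSCEV): every query starts with a fresh budget.
  Result get(const Node *N) {
    assert(Depth == 0 && "get() is not re-entrant in this toy model");
    Budget = Limit;
    return getWithLimits(N);
  }

  unsigned maxDepthSeen() const { return MaxDepth; }

private:
  // Recursive worker (cf. getSCEVWithLimits).
  Result getWithLimits(const Node *N) {
    // Already-analyzed nodes are free: they do not consume budget.
    if (auto It = Cache.find(N); It != Cache.end())
      return It->second;

    // Charge one unit per newly analyzed node; bail out once it is empty.
    if (Budget) {
      if (*Budget == 0)
        return std::nullopt;
      --*Budget;
    }

    ++Depth;
    if (Depth > MaxDepth)
      MaxDepth = Depth;

    Result R = N->Constant;
    for (const Node *Op : N->Operands) {
      Result OpR = getWithLimits(Op); // one more stack frame per level
      if (!OpR) {
        R = std::nullopt; // an unknown operand makes the sum unknown
        break;
      }
      *R += *OpR;
    }

    --Depth;
    Cache[N] = R;
    return R;
  }

  std::optional<unsigned> Limit;
  std::optional<unsigned> Budget;
  std::map<const Node *, Result> Cache;
  unsigned Depth = 0;
  unsigned MaxDepth = 0;
};

// Builds a chain N0 <- N1 <- ... <- N(Len-1) where every node contributes 1.
std::vector<Node> makeChain(unsigned Len) {
  std::vector<Node> Nodes(Len);
  for (unsigned I = 0; I < Len; ++I) {
    Nodes[I].Constant = 1;
    if (I > 0)
      Nodes[I].Operands.push_back(&Nodes[I - 1]);
  }
  return Nodes;
}
```

On a 1000-node chain, an unlimited analysis descends all 1000 levels and returns 1000, while a budget of 100 stops 100 levels down and returns "unknown" for the top node.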

## A little bit of background
At Azul Systems, our fuzzer generated a test case with very large IR, and computing SCEV on it crashed with a stack overflow. After some investigation, we found that one particular SCEV computation needed 6.2 MB of stack to complete, whereas our default stack size limit is 2 MB. In cases like this, SCEV computation is very stack-intensive, and we need a way to stop it gracefully once it exceeds some threshold. I thought a threshold on the number of Values involved in the computation might be a good way to implement this constraint, but I am open to suggestions on better ways to implement it.

I am not very familiar with the SCEV code and its use cases, and I am still getting acclimatised to it. I chose to implement this patch with a state variable, `MaxNumOfValuesToConsider`, in the ScalarEvolution class because that seemed easier than the alternative, i.e., modifying the internal and external APIs of the ScalarEvolution class to take `MaxNumOfValuesToConsider` as a function argument. However, I am open to suggestions on how this could be implemented better.
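For comparison, the function-argument alternative mentioned above might look like the following sketch, again over a hypothetical toy DAG rather than the real ScalarEvolution API (`Node`, `evaluate`, `Budget`, `Cache` and `makeChain` are all illustrative names). The remaining budget lives with the caller and is passed by reference, so no mutable per-query member is needed, but every recursive helper has to accept and forward the extra parameter. That is the API churn the state-variable approach avoids.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <vector>

// Hypothetical toy IR (not LLVM's API): value = Constant + sum of operands.
struct Node {
  long Constant = 0;
  std::vector<const Node *> Operands;
};

using Result = std::optional<long>;     // std::nullopt ~ SCEVUnknown
using Budget = std::optional<unsigned>; // std::nullopt = unlimited
using Cache = std::map<const Node *, Result>;

// The remaining budget is an explicit by-reference argument instead of a
// member variable, so every recursive call must thread it through.
Result evaluate(const Node *N, Cache &Memo, Budget &Remaining) {
  if (auto It = Memo.find(N); It != Memo.end())
    return It->second; // cached results cost no budget
  if (Remaining) {
    if (*Remaining == 0)
      return std::nullopt; // budget exhausted: give up on this node
    --*Remaining;
  }
  Result R = N->Constant;
  for (const Node *Op : N->Operands) {
    Result OpR = evaluate(Op, Memo, Remaining);
    if (!OpR) {
      R = std::nullopt; // an unknown operand makes the sum unknown
      break;
    }
    *R += *OpR;
  }
  Memo[N] = R;
  return R;
}

// Builds a chain N0 <- N1 <- ... <- N(Len-1) where every node contributes 1.
std::vector<Node> makeChain(unsigned Len) {
  std::vector<Node> Nodes(Len);
  for (unsigned I = 0; I < Len; ++I) {
    Nodes[I].Constant = 1;
    if (I > 0)
      Nodes[I].Operands.push_back(&Nodes[I - 1]);
  }
  return Nodes;
}
```

One side benefit of this shape is that the caller can see how much budget a query consumed; the cost is an extra parameter on every `getSCEV`-style entry point.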

I will add lit tests and update failing tests once the design of the patch is finalized.

cc @nikic 

From 15c6e34f846141a5250c07030e9eb25735053662 Mon Sep 17 00:00:00 2001
From: Manish Kausik H <hmanishkausik at gmail.com>
Date: Mon, 19 May 2025 21:00:31 +0530
Subject: [PATCH] Limit Number of Values used for SCEV Calculation

In large IR, a single SCEV computation can involve a very large number of
Values. Because SCEV is computed recursively, computing the SCEVs of all of
those Values can overflow the stack. This patch addresses this by limiting
the SCEV computation once the number of Values involved crosses a
threshold.
---
 llvm/include/llvm/Analysis/ScalarEvolution.h |   7 +
 llvm/lib/Analysis/ScalarEvolution.cpp        | 206 +++++++++++--------
 2 files changed, 129 insertions(+), 84 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 339bdfeb4956a..abdb32ca6c404 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1492,6 +1492,9 @@ class ScalarEvolution {
   /// Memoized values for the getConstantMultiple
   DenseMap<const SCEV *, APInt> ConstantMultipleCache;
 
+  // Holds how many more Values we can consider in current SCEV calculation
+  std::optional<unsigned> MaxNumOfValuesToConsider;
+
   /// Return the Value set from which the SCEV expr is generated.
   ArrayRef<Value *> getSCEVValues(const SCEV *S);
 
@@ -1771,6 +1774,10 @@ class ScalarEvolution {
   /// SCEVUnknowns and thus don't use this mechanism.
   ConstantRange getRangeForUnknownRecurrence(const SCEVUnknown *U);
 
+  /// Like getSCEV, but counts each newly analyzed Value against
+  /// MaxNumOfValuesToConsider and returns a SCEVUnknown once it is exhausted.
+  const SCEV *getSCEVWithLimits(Value *V);
+
   /// We know that there is no SCEV for the specified value.  Analyze the
   /// expression recursively.
   const SCEV *createSCEV(Value *V);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 5e01c29615fab..7fdc7d8dba9bb 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -256,6 +256,18 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
     cl::desc("Infer nuw/nsw flags using context where suitable"),
     cl::init(true));
 
+static cl::opt<bool> LimitNumValuesInSCEVCalculation(
+    "limit-num-values-in-scev-calculation", cl::Hidden,
+    cl::desc("Limit the number of Values to consider for Calculating SCEV of a "
+             "Value"),
+    cl::init(true));
+
+static cl::opt<unsigned> MaxValuesInSCEVCalculation(
+    "max-values-in-scev-calculation", cl::Hidden,
+    cl::desc("The Maximum number of Values to consider for Calculating SCEV of "
+             "a Value, if LimitNumValuesInSCEVCalculation is set"),
+    cl::init(100));
+
 //===----------------------------------------------------------------------===//
 //                           SCEV class definitions
 //===----------------------------------------------------------------------===//
@@ -4533,10 +4545,32 @@ void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
 /// Return an existing SCEV if it exists, otherwise analyze the expression and
 /// create a new one.
 const SCEV *ScalarEvolution::getSCEV(Value *V) {
+  // Each top-level query starts with a fresh budget of Values.
+  if (LimitNumValuesInSCEVCalculation)
+    MaxNumOfValuesToConsider = MaxValuesInSCEVCalculation;
+  else
+    MaxNumOfValuesToConsider = std::nullopt;
+
+  return getSCEVWithLimits(V);
+}
+
+// TODO: Give a better name for this function, Maybe?
+const SCEV *ScalarEvolution::getSCEVWithLimits(Value *V) {
   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
 
   if (const SCEV *S = getExistingSCEV(V))
     return S;
+
+  if (LimitNumValuesInSCEVCalculation) {
+    assert(MaxNumOfValuesToConsider.has_value());
+    if (*MaxNumOfValuesToConsider == 0)
+      // This SCEV Calculation involves calculating SCEV for too many Values.
+      // We should bail out.
+      return getUnknown(V);
+    else
+      MaxNumOfValuesToConsider = *MaxNumOfValuesToConsider - 1;
+  }
+
   return createSCEVIter(V);
 }
 
@@ -5470,7 +5504,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
   if (!BEValueV || !StartValueV)
     return std::nullopt;
 
-  const SCEV *BEValue = getSCEV(BEValueV);
+  const SCEV *BEValue = getSCEVWithLimits(BEValueV);
 
   // If the value coming around the backedge is an add with the symbolic
   // value we just inserted, possibly with casts that we can ignore under
@@ -5558,7 +5592,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
   //
 
   // Create a truncated addrec for which we will add a no overflow check (P1).
-  const SCEV *StartVal = getSCEV(StartValueV);
+  const SCEV *StartVal = getSCEVWithLimits(StartValueV);
   const SCEV *PHISCEV =
       getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
                     getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
@@ -5729,9 +5763,9 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
 
   const SCEV *Accum = nullptr;
   if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
-    Accum = getSCEV(BO->RHS);
+    Accum = getSCEVWithLimits(BO->RHS);
   else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
-    Accum = getSCEV(BO->LHS);
+    Accum = getSCEVWithLimits(BO->LHS);
 
   if (!Accum)
     return nullptr;
@@ -5742,7 +5776,7 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
   if (BO->IsNSW)
     Flags = setFlags(Flags, SCEV::FlagNSW);
 
-  const SCEV *StartVal = getSCEV(StartValueV);
+  const SCEV *StartVal = getSCEVWithLimits(StartValueV);
   const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
   insertValueToMap(PN, PHISCEV);
 
@@ -5807,7 +5841,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
 
   // Using this symbolic name for the PHI, analyze the value coming around
   // the back-edge.
-  const SCEV *BEValue = getSCEV(BEValueV);
+  const SCEV *BEValue = getSCEVWithLimits(BEValueV);
 
   // NOTE: If BEValue is loop invariant, we know that the PHI node just
   // has a special value for the first iteration of the loop.
@@ -5868,7 +5902,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
           // for instance.
         }
 
-        const SCEV *StartVal = getSCEV(StartValueV);
+        const SCEV *StartVal = getSCEVWithLimits(StartValueV);
         const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
 
         // Okay, for the entire analysis of this edge we assumed the PHI
@@ -5909,7 +5943,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
     const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
     if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
         isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
-      const SCEV *StartVal = getSCEV(StartValueV);
+      const SCEV *StartVal = getSCEVWithLimits(StartValueV);
       if (Start == StartVal) {
         // Okay, for the entire analysis of this edge we assumed the PHI
         // to be symbolic.  We now need to go back and purge all of the
@@ -5987,8 +6021,8 @@ const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
 
     if (BI && BI->isConditional() &&
         BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
-        properlyDominates(getSCEV(LHS), PN->getParent()) &&
-        properlyDominates(getSCEV(RHS), PN->getParent()))
+        properlyDominates(getSCEVWithLimits(LHS), PN->getParent()) &&
+        properlyDominates(getSCEVWithLimits(RHS), PN->getParent()))
       return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
   }
 
@@ -6024,10 +6058,11 @@ ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
     return nullptr;
 
   // Check if SCEV exprs for instructions are identical.
-  const SCEV *CommonSCEV = getSCEV(CommonInst);
+  const SCEV *CommonSCEV = getSCEVWithLimits(CommonInst);
   bool SCEVExprsIdentical =
-      all_of(drop_begin(PN->incoming_values()),
-             [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
+      all_of(drop_begin(PN->incoming_values()), [this, CommonSCEV](Value *V) {
+        return CommonSCEV == getSCEVWithLimits(V);
+      });
   return SCEVExprsIdentical ? CommonSCEV : nullptr;
 }
 
@@ -6040,7 +6075,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
   if (Value *V = simplifyInstruction(
           PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
                /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
-    return getSCEV(V);
+    return getSCEVWithLimits(V);
 
   if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
     return S;
@@ -6114,10 +6149,10 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
     // a > b ? b+x : a+x  ->  min(a, b)+x
     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
       bool Signed = ICI->isSigned();
-      const SCEV *LA = getSCEV(TrueVal);
-      const SCEV *RA = getSCEV(FalseVal);
-      const SCEV *LS = getSCEV(LHS);
-      const SCEV *RS = getSCEV(RHS);
+      const SCEV *LA = getSCEVWithLimits(TrueVal);
+      const SCEV *RA = getSCEVWithLimits(FalseVal);
+      const SCEV *LS = getSCEVWithLimits(LHS);
+      const SCEV *RS = getSCEVWithLimits(RHS);
       if (LA->getType()->isPointerTy()) {
         // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
         // Need to make sure we can't produce weird expressions involving
@@ -6163,9 +6198,9 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
     // x == 0 ? C+y : x+y  ->  umax(x, C)+y   iff C u<= 1
     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
         isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
-      const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
-      const SCEV *TrueValExpr = getSCEV(TrueVal);    // C+y
-      const SCEV *FalseValExpr = getSCEV(FalseVal);  // x+y
+      const SCEV *X = getNoopOrZeroExtend(getSCEVWithLimits(LHS), Ty);
+      const SCEV *TrueValExpr = getSCEVWithLimits(TrueVal);   // C+y
+      const SCEV *FalseValExpr = getSCEVWithLimits(FalseVal); // x+y
       const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
       const SCEV *C = getMinusSCEV(TrueValExpr, Y);  // C = (C+y)-y
       if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
@@ -6177,11 +6212,11 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
     //                    ->  umin_seq(x, umin (..., umin_seq(...), ...))
     if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
         isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
-      const SCEV *X = getSCEV(LHS);
+      const SCEV *X = getSCEVWithLimits(LHS);
       while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
         X = ZExt->getOperand();
       if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
-        const SCEV *FalseValExpr = getSCEV(FalseVal);
+        const SCEV *FalseValExpr = getSCEVWithLimits(FalseVal);
         if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
           return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
                              /*Sequential=*/true);
@@ -6264,7 +6299,7 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
   // Handle "constant" branch or select. This can occur for instance when a
   // loop pass transforms an inner loop and moves on to process the outer loop.
   if (auto *CI = dyn_cast<ConstantInt>(Cond))
-    return getSCEV(CI->isOne() ? TrueVal : FalseVal);
+    return getSCEVWithLimits(CI->isOne() ? TrueVal : FalseVal);
 
   if (auto *I = dyn_cast<Instruction>(V)) {
     if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
@@ -6286,7 +6321,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
 
   SmallVector<const SCEV *, 4> IndexExprs;
   for (Value *Index : GEP->indices())
-    IndexExprs.push_back(getSCEV(Index));
+    IndexExprs.push_back(getSCEVWithLimits(Index));
   return getGEPExpr(GEP, IndexExprs);
 }
 
@@ -6931,7 +6966,8 @@ const ConstantRange &ScalarEvolution::getRangeRef(
         ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
 
         for (const auto &Op : Phi->operands()) {
-          auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
+          auto OpRange =
+              getRangeRef(getSCEVWithLimits(Op), SignHint, Depth + 1);
           RangeFromOps = RangeFromOps.unionWith(OpRange);
           // No point to continue if we already have a full set.
           if (RangeFromOps.isFullSet())
@@ -7361,7 +7397,7 @@ bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
     // I could be an extractvalue from a call to an overflow intrinsic.
     // TODO: We can do better here in some cases.
     if (isSCEVable(Op->getType()))
-      SCEVOps.push_back(getSCEV(Op));
+      SCEVOps.push_back(getSCEVWithLimits(Op));
   }
   auto *DefI = getDefiningScopeBound(SCEVOps);
   return isGuaranteedToTransferExecutionTo(DefI, I);
@@ -7725,10 +7761,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
           // since the flags are only known to apply to this particular
           // addition - they may not apply to other additions that can be
           // formed with operands from AddOps.
-          const SCEV *RHS = getSCEV(BO->RHS);
+          const SCEV *RHS = getSCEVWithLimits(BO->RHS);
           SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
           if (Flags != SCEV::FlagAnyWrap) {
-            const SCEV *LHS = getSCEV(BO->LHS);
+            const SCEV *LHS = getSCEVWithLimits(BO->LHS);
             if (BO->Opcode == Instruction::Sub)
               AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
             else
@@ -7738,15 +7774,15 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         }
 
         if (BO->Opcode == Instruction::Sub)
-          AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
+          AddOps.push_back(getNegativeSCEV(getSCEVWithLimits(BO->RHS)));
         else
-          AddOps.push_back(getSCEV(BO->RHS));
+          AddOps.push_back(getSCEVWithLimits(BO->RHS));
 
         auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
                                    dyn_cast<Instruction>(V));
         if (!NewBO || (NewBO->Opcode != Instruction::Add &&
                        NewBO->Opcode != Instruction::Sub)) {
-          AddOps.push_back(getSCEV(BO->LHS));
+          AddOps.push_back(getSCEVWithLimits(BO->LHS));
           break;
         }
         BO = NewBO;
@@ -7766,18 +7802,18 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 
           SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
           if (Flags != SCEV::FlagAnyWrap) {
-            LHS = getSCEV(BO->LHS);
-            RHS = getSCEV(BO->RHS);
+            LHS = getSCEVWithLimits(BO->LHS);
+            RHS = getSCEVWithLimits(BO->RHS);
             MulOps.push_back(getMulExpr(LHS, RHS, Flags));
             break;
           }
         }
 
-        MulOps.push_back(getSCEV(BO->RHS));
+        MulOps.push_back(getSCEVWithLimits(BO->RHS));
         auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
                                    dyn_cast<Instruction>(V));
         if (!NewBO || NewBO->Opcode != Instruction::Mul) {
-          MulOps.push_back(getSCEV(BO->LHS));
+          MulOps.push_back(getSCEVWithLimits(BO->LHS));
           break;
         }
         BO = NewBO;
@@ -7786,19 +7822,19 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       return getMulExpr(MulOps);
     }
     case Instruction::UDiv:
-      LHS = getSCEV(BO->LHS);
-      RHS = getSCEV(BO->RHS);
+      LHS = getSCEVWithLimits(BO->LHS);
+      RHS = getSCEVWithLimits(BO->RHS);
       return getUDivExpr(LHS, RHS);
     case Instruction::URem:
-      LHS = getSCEV(BO->LHS);
-      RHS = getSCEV(BO->RHS);
+      LHS = getSCEVWithLimits(BO->LHS);
+      RHS = getSCEVWithLimits(BO->RHS);
       return getURemExpr(LHS, RHS);
     case Instruction::Sub: {
       SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
       if (BO->Op)
         Flags = getNoWrapFlagsFromUB(BO->Op);
-      LHS = getSCEV(BO->LHS);
-      RHS = getSCEV(BO->RHS);
+      LHS = getSCEVWithLimits(BO->LHS);
+      RHS = getSCEVWithLimits(BO->RHS);
       return getMinusSCEV(LHS, RHS, Flags);
     }
     case Instruction::And:
@@ -7806,9 +7842,9 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       // use zext(trunc(x)) as the SCEV expression.
       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
         if (CI->isZero())
-          return getSCEV(BO->RHS);
+          return getSCEVWithLimits(BO->RHS);
         if (CI->isMinusOne())
-          return getSCEV(BO->LHS);
+          return getSCEVWithLimits(BO->LHS);
         const APInt &A = CI->getValue();
 
         // Instcombine's ShrinkDemandedConstant may strip bits out of
@@ -7826,7 +7862,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
             APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
         if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
           const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
-          const SCEV *LHS = getSCEV(BO->LHS);
+          const SCEV *LHS = getSCEVWithLimits(BO->LHS);
           const SCEV *ShiftedLHS = nullptr;
           if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
             if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
@@ -7853,8 +7889,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       }
       // Binary `and` is a bit-wise `umin`.
       if (BO->LHS->getType()->isIntegerTy(1)) {
-        LHS = getSCEV(BO->LHS);
-        RHS = getSCEV(BO->RHS);
+        LHS = getSCEVWithLimits(BO->LHS);
+        RHS = getSCEVWithLimits(BO->RHS);
         return getUMinExpr(LHS, RHS);
       }
       break;
@@ -7862,8 +7898,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
     case Instruction::Or:
       // Binary `or` is a bit-wise `umax`.
       if (BO->LHS->getType()->isIntegerTy(1)) {
-        LHS = getSCEV(BO->LHS);
-        RHS = getSCEV(BO->RHS);
+        LHS = getSCEVWithLimits(BO->LHS);
+        RHS = getSCEVWithLimits(BO->RHS);
         return getUMaxExpr(LHS, RHS);
       }
       break;
@@ -7872,7 +7908,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
         // If the RHS of xor is -1, then this is a not operation.
         if (CI->isMinusOne())
-          return getNotSCEV(getSCEV(BO->LHS));
+          return getNotSCEV(getSCEVWithLimits(BO->LHS));
 
         // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
         // This is a variant of the check for xor with -1, and it handles
@@ -7882,8 +7918,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
           if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
             if (LBO->getOpcode() == Instruction::And &&
                 LCI->getValue() == CI->getValue())
-              if (const SCEVZeroExtendExpr *Z =
-                      dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
+              if (const SCEVZeroExtendExpr *Z = dyn_cast<SCEVZeroExtendExpr>(
+                      getSCEVWithLimits(BO->LHS))) {
                 Type *UTy = BO->LHS->getType();
                 const SCEV *Z0 = Z->getOperand();
                 Type *Z0Ty = Z0->getType();
@@ -7935,7 +7971,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 
         ConstantInt *X = ConstantInt::get(
             getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
-        return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
+        return getMulExpr(getSCEVWithLimits(BO->LHS), getConstant(X), Flags);
       }
       break;
 
@@ -7955,7 +7991,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         break;
 
       if (CI->isZero())
-        return getSCEV(BO->LHS); // shift by zero --> noop
+        return getSCEVWithLimits(BO->LHS); // shift by zero --> noop
 
       uint64_t AShrAmt = CI->getZExtValue();
       Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
@@ -7975,7 +8011,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
         if (LShift && LShift->getOpcode() == Instruction::Shl) {
           if (AddOperandCI) {
-            const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
+            const SCEV *ShlOp0SCEV = getSCEVWithLimits(LShift->getOperand(0));
             ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
             // since we truncate to TruncTy, the AddConstant should be of the
             // same type, so create a new Constant with type same as TruncTy.
@@ -7993,7 +8029,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         // Y = AShr X, m
         // Both n and m are constant.
 
-        const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
+        const SCEV *ShlOp0SCEV = getSCEVWithLimits(L->getOperand(0));
         ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
         AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
       }
@@ -8028,10 +8064,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 
   switch (U->getOpcode()) {
   case Instruction::Trunc:
-    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
+    return getTruncateExpr(getSCEVWithLimits(U->getOperand(0)), U->getType());
 
   case Instruction::ZExt:
-    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+    return getZeroExtendExpr(getSCEVWithLimits(U->getOperand(0)), U->getType());
 
   case Instruction::SExt:
     if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
@@ -8045,22 +8081,22 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       // but by that point the NSW information has potentially been lost.
       if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
         Type *Ty = U->getType();
-        auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
-        auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
+        auto *V1 = getSignExtendExpr(getSCEVWithLimits(BO->LHS), Ty);
+        auto *V2 = getSignExtendExpr(getSCEVWithLimits(BO->RHS), Ty);
         return getMinusSCEV(V1, V2, SCEV::FlagNSW);
       }
     }
-    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+    return getSignExtendExpr(getSCEVWithLimits(U->getOperand(0)), U->getType());
 
   case Instruction::BitCast:
     // BitCasts are no-op casts so we just eliminate the cast.
     if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
-      return getSCEV(U->getOperand(0));
+      return getSCEVWithLimits(U->getOperand(0));
     break;
 
   case Instruction::PtrToInt: {
     // Pointer to integer cast is straight-forward, so do model it.
-    const SCEV *Op = getSCEV(U->getOperand(0));
+    const SCEV *Op = getSCEVWithLimits(U->getOperand(0));
     Type *DstIntTy = U->getType();
     // But only if effective SCEV (integer) type is wide enough to represent
     // all possible pointer values.
@@ -8075,16 +8111,18 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 
   case Instruction::SDiv:
     // If both operands are non-negative, this is just an udiv.
-    if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
-        isKnownNonNegative(getSCEV(U->getOperand(1))))
-      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
+    if (isKnownNonNegative(getSCEVWithLimits(U->getOperand(0))) &&
+        isKnownNonNegative(getSCEVWithLimits(U->getOperand(1))))
+      return getUDivExpr(getSCEVWithLimits(U->getOperand(0)),
+                         getSCEVWithLimits(U->getOperand(1)));
     break;
 
   case Instruction::SRem:
     // If both operands are non-negative, this is just an urem.
-    if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
-        isKnownNonNegative(getSCEV(U->getOperand(1))))
-      return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
+    if (isKnownNonNegative(getSCEVWithLimits(U->getOperand(0))) &&
+        isKnownNonNegative(getSCEVWithLimits(U->getOperand(1))))
+      return getURemExpr(getSCEVWithLimits(U->getOperand(0)),
+                         getSCEVWithLimits(U->getOperand(1)));
     break;
 
   case Instruction::GetElementPtr:
@@ -8100,39 +8138,39 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
   case Instruction::Call:
   case Instruction::Invoke:
     if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
-      return getSCEV(RV);
+      return getSCEVWithLimits(RV);
 
     if (auto *II = dyn_cast<IntrinsicInst>(U)) {
       switch (II->getIntrinsicID()) {
       case Intrinsic::abs:
         return getAbsExpr(
-            getSCEV(II->getArgOperand(0)),
+            getSCEVWithLimits(II->getArgOperand(0)),
             /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
       case Intrinsic::umax:
-        LHS = getSCEV(II->getArgOperand(0));
-        RHS = getSCEV(II->getArgOperand(1));
+        LHS = getSCEVWithLimits(II->getArgOperand(0));
+        RHS = getSCEVWithLimits(II->getArgOperand(1));
         return getUMaxExpr(LHS, RHS);
       case Intrinsic::umin:
-        LHS = getSCEV(II->getArgOperand(0));
-        RHS = getSCEV(II->getArgOperand(1));
+        LHS = getSCEVWithLimits(II->getArgOperand(0));
+        RHS = getSCEVWithLimits(II->getArgOperand(1));
         return getUMinExpr(LHS, RHS);
       case Intrinsic::smax:
-        LHS = getSCEV(II->getArgOperand(0));
-        RHS = getSCEV(II->getArgOperand(1));
+        LHS = getSCEVWithLimits(II->getArgOperand(0));
+        RHS = getSCEVWithLimits(II->getArgOperand(1));
         return getSMaxExpr(LHS, RHS);
       case Intrinsic::smin:
-        LHS = getSCEV(II->getArgOperand(0));
-        RHS = getSCEV(II->getArgOperand(1));
+        LHS = getSCEVWithLimits(II->getArgOperand(0));
+        RHS = getSCEVWithLimits(II->getArgOperand(1));
         return getSMinExpr(LHS, RHS);
       case Intrinsic::usub_sat: {
-        const SCEV *X = getSCEV(II->getArgOperand(0));
-        const SCEV *Y = getSCEV(II->getArgOperand(1));
+        const SCEV *X = getSCEVWithLimits(II->getArgOperand(0));
+        const SCEV *Y = getSCEVWithLimits(II->getArgOperand(1));
         const SCEV *ClampedY = getUMinExpr(X, Y);
         return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
       }
       case Intrinsic::uadd_sat: {
-        const SCEV *X = getSCEV(II->getArgOperand(0));
-        const SCEV *Y = getSCEV(II->getArgOperand(1));
+        const SCEV *X = getSCEVWithLimits(II->getArgOperand(0));
+        const SCEV *Y = getSCEVWithLimits(II->getArgOperand(1));
         const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
         return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
       }
@@ -8141,7 +8179,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       case Intrinsic::ptr_annotation:
         // A start_loop_iterations or llvm.annotation or llvm.prt.annotation is
         // just eqivalent to the first operand for SCEV purposes.
-        return getSCEV(II->getArgOperand(0));
+        return getSCEVWithLimits(II->getArgOperand(0));
       case Intrinsic::vscale:
         return getVScale(II->getType());
       default:
