[llvm] [SCEV][LAA] Support multiplication overflow computation (PR #155236)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 25 05:03:46 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: None (annamthomas)

<details>
<summary>Changes</summary>

Add support for identifying multiplication overflow in SCEV.
This is needed in LoopAccessAnalysis; the lack of this support was worked
around by 484417a.
This patch reverts that workaround, and early-exit
vectorization works as expected in the vect.stats.ll test.


---
Full diff: https://github.com/llvm/llvm-project/pull/155236.diff


3 Files Affected:

- (modified) llvm/lib/Analysis/Loads.cpp (-8) 
- (modified) llvm/lib/Analysis/LoopAccessAnalysis.cpp (+13-9) 
- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+9-3) 


``````````diff
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 9a2c9ba63ec7e..b5600f2b0822f 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -332,14 +332,6 @@ bool llvm::isDereferenceableAndAlignedInLoop(
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     return false;
 
-  if (isa<SCEVCouldNotCompute>(BECount)) {
-    // TODO: Support symbolic max backedge taken counts for loops without
-    // computable backedge taken counts.
-    MaxBECount =
-        Predicates
-            ? SE.getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
-            : SE.getConstantMaxBackedgeTakenCount(L);
-  }
   const auto &[AccessStart, AccessEnd] = getStartAndEndForAccess(
       L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr, &DT, AC);
   if (isa<SCEVCouldNotCompute>(AccessStart) ||
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index bceddd0325276..258fa982ed1d0 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -193,8 +193,9 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
 /// Returns \p A + \p B, if it is guaranteed not to unsigned wrap. Otherwise
 /// return nullptr. \p A and \p B must have the same type.
 static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
-                                     ScalarEvolution &SE) {
-  if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B))
+                                     ScalarEvolution &SE,
+                                     const Instruction *CtxI) {
+  if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B, CtxI))
     return nullptr;
   return SE.getAddExpr(A, B);
 }
@@ -202,8 +203,9 @@ static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
 /// Returns \p A * \p B, if it is guaranteed not to unsigned wrap. Otherwise
 /// return nullptr. \p A and \p B must have the same type.
 static const SCEV *mulSCEVOverflow(const SCEV *A, const SCEV *B,
-                                   ScalarEvolution &SE) {
-  if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B))
+                                   ScalarEvolution &SE,
+                                   const Instruction *CtxI) {
+  if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B, CtxI))
     return nullptr;
   return SE.getMulExpr(A, B);
 }
@@ -232,11 +234,12 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
   Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType());
   const SCEV *DerefBytesSCEV = SE.getConstant(WiderTy, DerefBytes);
 
+  // Context which dominates the entire loop.
+  auto *CtxI = L->getLoopPredecessor()->getTerminator();
   // Check if we have a suitable dereferencable assumption we can use.
   if (!StartPtrV->canBeFreed()) {
     RetainedKnowledge DerefRK = getKnowledgeValidInContext(
-        StartPtrV, {Attribute::Dereferenceable}, *AC,
-        L->getLoopPredecessor()->getTerminator(), DT);
+        StartPtrV, {Attribute::Dereferenceable}, *AC, CtxI, DT);
     if (DerefRK) {
       DerefBytesSCEV = SE.getUMaxExpr(
           DerefBytesSCEV, SE.getConstant(WiderTy, DerefRK.ArgValue));
@@ -260,12 +263,12 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
       SE.getMinusSCEV(AR->getStart(), StartPtr), WiderTy);
 
   const SCEV *OffsetAtLastIter =
-      mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE);
+      mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE, CtxI);
   if (!OffsetAtLastIter)
     return false;
 
   const SCEV *OffsetEndBytes = addSCEVNoOverflow(
-      OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE);
+      OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE, CtxI);
   if (!OffsetEndBytes)
     return false;
 
@@ -273,7 +276,8 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
     // For positive steps, check if
     //  (AR->getStart() - StartPtr) + (MaxBTC  * Step) + EltSize <= DerefBytes,
     // while making sure none of the computations unsigned wrap themselves.
-    const SCEV *EndBytes = addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE);
+    const SCEV *EndBytes =
+        addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE, CtxI);
     if (!EndBytes)
       return false;
     return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index f60a1e9f22704..2b638003043d9 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -2337,15 +2337,21 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
   // Can we use context to prove the fact we need?
   if (!CtxI)
     return false;
-  // TODO: Support mul.
-  if (BinOp == Instruction::Mul)
-    return false;
   auto *RHSC = dyn_cast<SCEVConstant>(RHS);
   // TODO: Lift this limitation.
   if (!RHSC)
     return false;
   APInt C = RHSC->getAPInt();
   unsigned NumBits = C.getBitWidth();
+  if (BinOp == Instruction::Mul) {
+    // Multiplying by 0 or 1 never overflows
+    if (C.isZero() || C.isOne())
+      return true;
+    auto Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
+    APInt Limit = APInt::getMaxValue(NumBits).udiv(C);
+    // To avoid overflow, we need to make sure that LHS <= MAX / C.
+    return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
+  }
   bool IsSub = (BinOp == Instruction::Sub);
   bool IsNegativeConst = (Signed && C.isNegative());
   // Compute the direction and magnitude by which we need to check overflow.

``````````

</details>


https://github.com/llvm/llvm-project/pull/155236


More information about the llvm-commits mailing list