[llvm] [SCEV][LAA] Support multiplication overflow computation (PR #155236)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 25 05:03:12 PDT 2025


https://github.com/annamthomas created https://github.com/llvm/llvm-project/pull/155236

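Teach ScalarEvolution::willNotOverflow to use context to prove that a
multiplication by a constant does not overflow. LoopAccessAnalysis needs
this; the missing support was previously worked around by 484417a.
This patch reverts that workaround and passes the loop predecessor's
terminator as context to the overflow checks in LoopAccessAnalysis. With
this, early-exit vectorization works as expected in the vect.stats.ll test.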


>From 25b42ab6c8d711499e3bdcbdcf2bd725de308c0c Mon Sep 17 00:00:00 2001
From: Anna Thomas <anna at azul.com>
Date: Sat, 23 Aug 2025 09:47:00 -0400
Subject: [PATCH] [SCEV][LAA] Support multiplication overflow computation

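Teach ScalarEvolution::willNotOverflow to use context to prove that a
multiplication by a constant does not overflow. LoopAccessAnalysis needs
this; the missing support was previously worked around by 484417a.
This patch reverts that workaround and passes the loop predecessor's
terminator as context to the overflow checks in LoopAccessAnalysis. With
this, early-exit vectorization works as expected in the vect.stats.ll test.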
---
 llvm/lib/Analysis/Loads.cpp              |  8 --------
 llvm/lib/Analysis/LoopAccessAnalysis.cpp | 22 +++++++++++++---------
 llvm/lib/Analysis/ScalarEvolution.cpp    | 12 +++++++++---
 3 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 9a2c9ba63ec7e..b5600f2b0822f 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -332,14 +332,6 @@ bool llvm::isDereferenceableAndAlignedInLoop(
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     return false;
 
-  if (isa<SCEVCouldNotCompute>(BECount)) {
-    // TODO: Support symbolic max backedge taken counts for loops without
-    // computable backedge taken counts.
-    MaxBECount =
-        Predicates
-            ? SE.getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
-            : SE.getConstantMaxBackedgeTakenCount(L);
-  }
   const auto &[AccessStart, AccessEnd] = getStartAndEndForAccess(
       L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr, &DT, AC);
   if (isa<SCEVCouldNotCompute>(AccessStart) ||
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index bceddd0325276..258fa982ed1d0 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -193,8 +193,9 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
 /// Returns \p A + \p B, if it is guaranteed not to unsigned wrap. Otherwise
 /// return nullptr. \p A and \p B must have the same type.
 static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
-                                     ScalarEvolution &SE) {
-  if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B))
+                                     ScalarEvolution &SE,
+                                     const Instruction *CtxI) {
+  if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B, CtxI))
     return nullptr;
   return SE.getAddExpr(A, B);
 }
@@ -202,8 +203,9 @@ static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
 /// Returns \p A * \p B, if it is guaranteed not to unsigned wrap. Otherwise
 /// return nullptr. \p A and \p B must have the same type.
 static const SCEV *mulSCEVOverflow(const SCEV *A, const SCEV *B,
-                                   ScalarEvolution &SE) {
-  if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B))
+                                   ScalarEvolution &SE,
+                                   const Instruction *CtxI) {
+  if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B, CtxI))
     return nullptr;
   return SE.getMulExpr(A, B);
 }
@@ -232,11 +234,12 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
   Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType());
   const SCEV *DerefBytesSCEV = SE.getConstant(WiderTy, DerefBytes);
 
+  // Context which dominates the entire loop.
+  auto *CtxI = L->getLoopPredecessor()->getTerminator();
   // Check if we have a suitable dereferencable assumption we can use.
   if (!StartPtrV->canBeFreed()) {
     RetainedKnowledge DerefRK = getKnowledgeValidInContext(
-        StartPtrV, {Attribute::Dereferenceable}, *AC,
-        L->getLoopPredecessor()->getTerminator(), DT);
+        StartPtrV, {Attribute::Dereferenceable}, *AC, CtxI, DT);
     if (DerefRK) {
       DerefBytesSCEV = SE.getUMaxExpr(
           DerefBytesSCEV, SE.getConstant(WiderTy, DerefRK.ArgValue));
@@ -260,12 +263,12 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
       SE.getMinusSCEV(AR->getStart(), StartPtr), WiderTy);
 
   const SCEV *OffsetAtLastIter =
-      mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE);
+      mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE, CtxI);
   if (!OffsetAtLastIter)
     return false;
 
   const SCEV *OffsetEndBytes = addSCEVNoOverflow(
-      OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE);
+      OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE, CtxI);
   if (!OffsetEndBytes)
     return false;
 
@@ -273,7 +276,8 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
     // For positive steps, check if
     //  (AR->getStart() - StartPtr) + (MaxBTC  * Step) + EltSize <= DerefBytes,
     // while making sure none of the computations unsigned wrap themselves.
-    const SCEV *EndBytes = addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE);
+    const SCEV *EndBytes =
+        addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE, CtxI);
     if (!EndBytes)
       return false;
     return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index f60a1e9f22704..2b638003043d9 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -2337,15 +2337,21 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
   // Can we use context to prove the fact we need?
   if (!CtxI)
     return false;
-  // TODO: Support mul.
-  if (BinOp == Instruction::Mul)
-    return false;
   auto *RHSC = dyn_cast<SCEVConstant>(RHS);
   // TODO: Lift this limitation.
   if (!RHSC)
     return false;
   APInt C = RHSC->getAPInt();
   unsigned NumBits = C.getBitWidth();
+  if (BinOp == Instruction::Mul) {
+    // Multiplying by 0 or 1 never overflows. Otherwise LHS * C does not
+    // unsigned-wrap iff LHS <=u UMAX / C. TODO: Support signed mul.
+    if (C.isZero() || C.isOne())
+      return true;
+    APInt Limit = APInt::getMaxValue(NumBits).udiv(C);
+    return !Signed && isKnownPredicateAt(ICmpInst::ICMP_ULE, LHS,
+                                         getConstant(Limit), CtxI);
+  }
   bool IsSub = (BinOp == Instruction::Sub);
   bool IsNegativeConst = (Signed && C.isNegative());
   // Compute the direction and magnitude by which we need to check overflow.



More information about the llvm-commits mailing list