[llvm] [SCEV][LAA] Support multiplication overflow computation (PR #155236)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 26 06:06:24 PDT 2025
https://github.com/annamthomas updated https://github.com/llvm/llvm-project/pull/155236
>From 5bdf5af16f8e9581aa1c668b86c356f3afabefef Mon Sep 17 00:00:00 2001
From: Anna Thomas <anna at azul.com>
Date: Sat, 23 Aug 2025 09:47:00 -0400
Subject: [PATCH] [SCEV][LAA] Support multiplication overflow computation
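Teach ScalarEvolution::willNotOverflow to reason about multiplication by a
constant C. Multiplying by 0 or 1 never overflows. For an unsigned multiply
LHS * C, no overflow is proven when LHS <= (maximum unsigned value of the
type) / C holds at the given context instruction. Signed multiplies by
other constants are still not handled.

LoopAccessAnalysis needs this; the missing support was previously worked
around by 484417a. LAA now passes the loop predecessor's terminator as the
context instruction for its add/mul no-overflow checks.

The patch also adds a -use-symbolic-maxbtc-deref-loop option (off by
default). It makes isDereferenceableAndAlignedInLoop keep the previously
computed MaxBECount instead of replacing it with the constant maximum
backedge-taken count when the exact count is not computable. The
vect.stats.ll RUN line enables this option. With it, early-exit
vectorization works as expected in that test without the 484417a
workaround.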
---
llvm/lib/Analysis/Loads.cpp | 7 +++++-
llvm/lib/Analysis/LoopAccessAnalysis.cpp | 22 +++++++++++--------
llvm/lib/Analysis/ScalarEvolution.cpp | 14 +++++++++---
.../Transforms/LoopVectorize/vect.stats.ll | 2 +-
4 files changed, 31 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 9a2c9ba63ec7e..7a8fbbd0fb919 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -26,6 +26,10 @@
using namespace llvm;
+static cl::opt<bool>
+ UseSymbolicMaxBTCForDerefInLoop("use-symbolic-maxbtc-deref-loop",
+ cl::init(false));
+
static bool isAligned(const Value *Base, Align Alignment,
const DataLayout &DL) {
return Base->getPointerAlignment(DL) >= Alignment;
@@ -332,7 +336,7 @@ bool llvm::isDereferenceableAndAlignedInLoop(
if (isa<SCEVCouldNotCompute>(MaxBECount))
return false;
- if (isa<SCEVCouldNotCompute>(BECount)) {
+ if (isa<SCEVCouldNotCompute>(BECount) && !UseSymbolicMaxBTCForDerefInLoop) {
// TODO: Support symbolic max backedge taken counts for loops without
// computable backedge taken counts.
MaxBECount =
@@ -340,6 +344,7 @@ bool llvm::isDereferenceableAndAlignedInLoop(
? SE.getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
: SE.getConstantMaxBackedgeTakenCount(L);
}
+
const auto &[AccessStart, AccessEnd] = getStartAndEndForAccess(
L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr, &DT, AC);
if (isa<SCEVCouldNotCompute>(AccessStart) ||
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index bceddd0325276..258fa982ed1d0 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -193,8 +193,9 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
/// Returns \p A + \p B, if it is guaranteed not to unsigned wrap. Otherwise
/// return nullptr. \p A and \p B must have the same type.
static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
- ScalarEvolution &SE) {
- if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B))
+ ScalarEvolution &SE,
+ const Instruction *CtxI) {
+ if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B, CtxI))
return nullptr;
return SE.getAddExpr(A, B);
}
@@ -202,8 +203,9 @@ static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
/// Returns \p A * \p B, if it is guaranteed not to unsigned wrap. Otherwise
/// return nullptr. \p A and \p B must have the same type.
static const SCEV *mulSCEVOverflow(const SCEV *A, const SCEV *B,
- ScalarEvolution &SE) {
- if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B))
+ ScalarEvolution &SE,
+ const Instruction *CtxI) {
+ if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B, CtxI))
return nullptr;
return SE.getMulExpr(A, B);
}
@@ -232,11 +234,12 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType());
const SCEV *DerefBytesSCEV = SE.getConstant(WiderTy, DerefBytes);
+ // Context which dominates the entire loop.
+ auto *CtxI = L->getLoopPredecessor()->getTerminator();
// Check if we have a suitable dereferencable assumption we can use.
if (!StartPtrV->canBeFreed()) {
RetainedKnowledge DerefRK = getKnowledgeValidInContext(
- StartPtrV, {Attribute::Dereferenceable}, *AC,
- L->getLoopPredecessor()->getTerminator(), DT);
+ StartPtrV, {Attribute::Dereferenceable}, *AC, CtxI, DT);
if (DerefRK) {
DerefBytesSCEV = SE.getUMaxExpr(
DerefBytesSCEV, SE.getConstant(WiderTy, DerefRK.ArgValue));
@@ -260,12 +263,12 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
SE.getMinusSCEV(AR->getStart(), StartPtr), WiderTy);
const SCEV *OffsetAtLastIter =
- mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE);
+ mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE, CtxI);
if (!OffsetAtLastIter)
return false;
const SCEV *OffsetEndBytes = addSCEVNoOverflow(
- OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE);
+ OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE, CtxI);
if (!OffsetEndBytes)
return false;
@@ -273,7 +276,8 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
// For positive steps, check if
// (AR->getStart() - StartPtr) + (MaxBTC * Step) + EltSize <= DerefBytes,
// while making sure none of the computations unsigned wrap themselves.
- const SCEV *EndBytes = addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE);
+ const SCEV *EndBytes =
+ addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE, CtxI);
if (!EndBytes)
return false;
return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index f60a1e9f22704..1b73b750846f4 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -2337,15 +2337,23 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
// Can we use context to prove the fact we need?
if (!CtxI)
return false;
- // TODO: Support mul.
- if (BinOp == Instruction::Mul)
- return false;
auto *RHSC = dyn_cast<SCEVConstant>(RHS);
// TODO: Lift this limitation.
if (!RHSC)
return false;
APInt C = RHSC->getAPInt();
unsigned NumBits = C.getBitWidth();
+ if (BinOp == Instruction::Mul) {
+ // Multiplying by 0 or 1 never overflows
+ if (C.isZero() || C.isOne())
+ return true;
+ if (Signed)
+ return false;
+ APInt Limit = APInt::getMaxValue(NumBits).udiv(C);
+ // To avoid overflow, we need to make sure that LHS <= MAX / C.
+ return isKnownPredicateAt(ICmpInst::ICMP_ULE, LHS, getConstant(Limit),
+ CtxI);
+ }
bool IsSub = (BinOp == Instruction::Sub);
bool IsNegativeConst = (Signed && C.isNegative());
// Compute the direction and magnitude by which we need to check overflow.
diff --git a/llvm/test/Transforms/LoopVectorize/vect.stats.ll b/llvm/test/Transforms/LoopVectorize/vect.stats.ll
index e3240c8181519..f3695e6712952 100644
--- a/llvm/test/Transforms/LoopVectorize/vect.stats.ll
+++ b/llvm/test/Transforms/LoopVectorize/vect.stats.ll
@@ -1,4 +1,4 @@
-; RUN: opt < %s -passes=loop-vectorize -force-vector-interleave=4 -force-vector-width=4 -debug-only=loop-vectorize -enable-early-exit-vectorization --disable-output -stats -S 2>&1 | FileCheck %s
+; RUN: opt < %s -passes=loop-vectorize -force-vector-interleave=4 -force-vector-width=4 -debug-only=loop-vectorize -enable-early-exit-vectorization -use-symbolic-maxbtc-deref-loop --disable-output -stats -S 2>&1 | FileCheck %s
; REQUIRES: asserts
; We have 3 loops, two of them are vectorizable (with one being early-exit
More information about the llvm-commits
mailing list