[llvm] [DA] runtime predicates for delinearization bounds checks (PR #170713)
Sjoerd Meijer via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 5 08:26:39 PST 2025
================
@@ -753,37 +753,137 @@ static bool isKnownLessThan(ScalarEvolution *SE, const SCEV *S,
return SE->isKnownNegative(LimitedBound);
}
-bool llvm::validateDelinearizationResult(ScalarEvolution &SE,
- ArrayRef<const SCEV *> Sizes,
- ArrayRef<const SCEV *> Subscripts,
- const Value *Ptr) {
+bool llvm::validateDelinearizationResult(
+ ScalarEvolution &SE, ArrayRef<const SCEV *> Sizes,
+ ArrayRef<const SCEV *> Subscripts, const Value *Ptr,
+ SmallVectorImpl<const SCEVPredicate *> *Assume) {
// Sizes and Subscripts are as follows:
- //
// Sizes: [UNK][S_2]...[S_n]
// Subscripts: [I_1][I_2]...[I_n]
//
// where the size of the outermost dimension is unknown (UNK).
- auto AddOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
- if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/true, A, B))
+ // Unify types of two SCEVs to the wider type.
+ auto UnifyTypes =
+ [&](const SCEV *&A,
+ const SCEV *&B) -> std::pair<const SCEV *, const SCEV *> {
+ Type *WiderType = SE.getWiderType(A->getType(), B->getType());
+ return {SE.getNoopOrSignExtend(A, WiderType),
+ SE.getNoopOrSignExtend(B, WiderType)};
+ };
+
+ // Get a type with twice the bit width of T.
+ auto GetWiderType = [&](Type *T) -> Type * {
+ unsigned BitWidth = SE.getTypeSizeInBits(T);
+ return IntegerType::get(T->getContext(), BitWidth * 2);
+ };
+
+ // Check if the result of A + B (signed) does not overflow. If it can be
+ // proven at compile-time, return the result. If it might overflow and Assume
+ // is provided, add a runtime equality predicate and return the result.
+ // Otherwise return nullptr.
+ auto AddNoOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
+ std::tie(A, B) = UnifyTypes(A, B);
+ if (SE.willNotOverflow(Instruction::Add, /*IsSigned=*/true, A, B))
+ return SE.getAddExpr(A, B);
+ if (!Assume)
return nullptr;
- return SE.getAddExpr(A, B);
+
+ // Compute the addition in a wider type to detect overflow.
+ // If (sext A) + (sext B) == sext(A + B), then A + B does not overflow.
+ Type *OrigTy = A->getType();
+ Type *WiderTy = GetWiderType(OrigTy);
+ const SCEV *AWide = SE.getSignExtendExpr(A, WiderTy);
+ const SCEV *BWide = SE.getSignExtendExpr(B, WiderTy);
+ const SCEV *SumWide = SE.getAddExpr(AWide, BWide);
+ const SCEV *Sum = SE.getAddExpr(A, B);
+ const SCEV *SumExtended = SE.getSignExtendExpr(Sum, WiderTy);
+ // Add predicate: (sext A) + (sext B) == sext(A + B).
+ if (SumWide != SumExtended &&
+ !SE.isKnownPredicate(ICmpInst::ICMP_EQ, SumWide, SumExtended))
+ Assume->push_back(SE.getEqualPredicate(SumWide, SumExtended));
+ return Sum;
};
- auto MulOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
- if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/true, A, B))
+ // Check if the result of A * B (signed) does not overflow. If it can be
+ // proven at compile-time, return the result. If it might overflow and Assume
+ // is provided, add a runtime equality predicate and return the result.
+ // Otherwise return nullptr.
+ auto MulNoOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
+ std::tie(A, B) = UnifyTypes(A, B);
+ if (SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/true, A, B))
+ return SE.getMulExpr(A, B);
+ if (!Assume)
return nullptr;
- return SE.getMulExpr(A, B);
+
+ // Compute the multiplication in a wider type to detect overflow.
+ // If (sext A) * (sext B) == sext(A * B), then A * B does not overflow.
----------------
sjoerdmeijer wrote:
Same?
https://github.com/llvm/llvm-project/pull/170713
More information about the llvm-commits mailing list