[llvm] [DA] runtime predicates for delinearization bounds checks (PR #170713)

Sjoerd Meijer via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 5 08:26:39 PST 2025


================
@@ -753,37 +753,137 @@ static bool isKnownLessThan(ScalarEvolution *SE, const SCEV *S,
   return SE->isKnownNegative(LimitedBound);
 }
 
-bool llvm::validateDelinearizationResult(ScalarEvolution &SE,
-                                         ArrayRef<const SCEV *> Sizes,
-                                         ArrayRef<const SCEV *> Subscripts,
-                                         const Value *Ptr) {
+bool llvm::validateDelinearizationResult(
+    ScalarEvolution &SE, ArrayRef<const SCEV *> Sizes,
+    ArrayRef<const SCEV *> Subscripts, const Value *Ptr,
+    SmallVectorImpl<const SCEVPredicate *> *Assume) {
   // Sizes and Subscripts are as follows:
-  //
   //   Sizes:      [UNK][S_2]...[S_n]
   //   Subscripts: [I_1][I_2]...[I_n]
   //
   // where the size of the outermost dimension is unknown (UNK).
 
-  auto AddOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
-    if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/true, A, B))
+  // Sign-extend both SCEVs to the wider of their two types.
+  auto UnifyTypes =
+      [&](const SCEV *&A,
+          const SCEV *&B) -> std::pair<const SCEV *, const SCEV *> {
+    Type *WiderType = SE.getWiderType(A->getType(), B->getType());
+    return {SE.getNoopOrSignExtend(A, WiderType),
+            SE.getNoopOrSignExtend(B, WiderType)};
+  };
+
+  // Get a type with twice the bit width of T.
+  auto GetWiderType = [&](Type *T) -> Type * {
+    unsigned BitWidth = SE.getTypeSizeInBits(T);
+    return IntegerType::get(T->getContext(), BitWidth * 2);
+  };
+
+  // Check if the result of A + B (signed) does not overflow. If it can be
+  // proven at compile-time, return the result. If it might overflow and Assume
+  // is provided, add a runtime equality predicate and return the result.
+  // Otherwise return nullptr.
+  auto AddNoOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
+    std::tie(A, B) = UnifyTypes(A, B);
+    if (SE.willNotOverflow(Instruction::Add, /*IsSigned=*/true, A, B))
+      return SE.getAddExpr(A, B);
+    if (!Assume)
       return nullptr;
-    return SE.getAddExpr(A, B);
+
+    // Compute the addition in a wider type to detect overflow.
+    // If (sext A) + (sext B) == sext(A + B), then A + B does not overflow.
+    Type *OrigTy = A->getType();
+    Type *WiderTy = GetWiderType(OrigTy);
+    const SCEV *AWide = SE.getSignExtendExpr(A, WiderTy);
+    const SCEV *BWide = SE.getSignExtendExpr(B, WiderTy);
+    const SCEV *SumWide = SE.getAddExpr(AWide, BWide);
+    const SCEV *Sum = SE.getAddExpr(A, B);
+    const SCEV *SumExtended = SE.getSignExtendExpr(Sum, WiderTy);
+    // Add predicate (sext A) + (sext B) == sext(A + B) unless already known.
+    if (SumWide != SumExtended &&
+        !SE.isKnownPredicate(ICmpInst::ICMP_EQ, SumWide, SumExtended))
+      Assume->push_back(SE.getEqualPredicate(SumWide, SumExtended));
+    return Sum;
   };
 
-  auto MulOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
-    if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/true, A, B))
+  // Check if the result of A * B (signed) does not overflow. If it can be
+  // proven at compile-time, return the result. If it might overflow and Assume
+  // is provided, add a runtime equality predicate and return the result.
+  // Otherwise return nullptr.
+  auto MulNoOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
+    std::tie(A, B) = UnifyTypes(A, B);
+    if (SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/true, A, B))
+      return SE.getMulExpr(A, B);
+    if (!Assume)
       return nullptr;
-    return SE.getMulExpr(A, B);
+
+    // Compute the multiplication in a wider type to detect overflow.
+    // If (sext A) * (sext B) == sext(A * B), then A * B does not overflow.
----------------
sjoerdmeijer wrote:

Same?

https://github.com/llvm/llvm-project/pull/170713


More information about the llvm-commits mailing list