[llvm] [ConstraintElim] Do not allow overflows in `Decomposition` (PR #140541)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Mon May 19 06:45:14 PDT 2025


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/140541

Consider the following case:
```
define i1 @pr140481(i32 %x) {
  %cond = icmp slt i32 %x, 0
  call void @llvm.assume(i1 %cond)
  %add = add nsw i32 %x, 5001000
  %mul1 = mul nsw i32 %add, -5001000
  %mul2 = mul nsw i32 %mul1, 5001000
  %cmp2 = icmp sgt i32 %mul2, 0
  ret i1 %cmp2
}
```
Before this patch, `decompose(%mul2)` returns `-25010001000000 * %x + 4052193514966861312`.
As a result, `%cmp2` is simplified to true, because `%x s< 0 && -25010001000000 * %x + 4052193514966861312 s<= 0` is unsatisfiable.

This is incorrect because computing the offset `-25010001000000 * 5001000` overflows as a signed 64-bit multiplication and wraps to `4052193514966861312`.
This patch treats a decomposition as invalid if a signed overflow occurs while computing its coefficients or its offset.

>From 8cb7f7df1ce079ceecd290a9beb5af1562a7e037 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 19 May 2025 21:21:39 +0800
Subject: [PATCH 1/2] [ConstraintElim] Add pre-commit tests. NFC.

---
 .../constraint-overflow.ll                    | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll b/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll
index 57b7b11be0cf1..5dc9ade756d49 100644
--- a/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll
+++ b/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll
@@ -52,3 +52,24 @@ entry:
   %c = icmp slt i64 0, %sub
   ret i1 %c
 }
+
+define i1 @pr140481(i32 %x) {
+; CHECK-LABEL: define i1 @pr140481(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND:%.*]] = icmp slt i32 [[X]], 0
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[X]], 5001000
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i32 [[ADD]], -5001000
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i32 [[MUL1]], 5001000
+; CHECK-NEXT:    ret i1 true
+;
+entry:
+  %cond = icmp slt i32 %x, 0
+  call void @llvm.assume(i1 %cond)
+  %add = add nsw i32 %x, 5001000
+  %mul1 = mul nsw i32 %add, -5001000
+  %mul2 = mul nsw i32 %mul1, 5001000
+  %cmp2 = icmp sgt i32 %mul2, 0
+  ret i1 %cmp2
+}

>From 854df22efb6c18b88e29f23def8c3b3643083ba7 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 19 May 2025 21:31:32 +0800
Subject: [PATCH 2/2] [ConstraintElim] Do not allow overflows in Decomposition

---
 .../Scalar/ConstraintElimination.cpp          | 102 +++++++++---------
 .../constraint-overflow.ll                    |   3 +-
 2 files changed, 56 insertions(+), 49 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index da5be383df15c..40d39f5455994 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -64,20 +64,6 @@ static cl::opt<bool> DumpReproducers(
 static int64_t MaxConstraintValue = std::numeric_limits<int64_t>::max();
 static int64_t MinSignedConstraintValue = std::numeric_limits<int64_t>::min();
 
-// A helper to multiply 2 signed integers where overflowing is allowed.
-static int64_t multiplyWithOverflow(int64_t A, int64_t B) {
-  int64_t Result;
-  MulOverflow(A, B, Result);
-  return Result;
-}
-
-// A helper to add 2 signed integers where overflowing is allowed.
-static int64_t addWithOverflow(int64_t A, int64_t B) {
-  int64_t Result;
-  AddOverflow(A, B, Result);
-  return Result;
-}
-
 static Instruction *getContextInstForUse(Use &U) {
   Instruction *UserI = cast<Instruction>(U.getUser());
   if (auto *Phi = dyn_cast<PHINode>(UserI))
@@ -366,26 +352,38 @@ struct Decomposition {
   Decomposition(int64_t Offset, ArrayRef<DecompEntry> Vars)
       : Offset(Offset), Vars(Vars) {}
 
-  void add(int64_t OtherOffset) {
-    Offset = addWithOverflow(Offset, OtherOffset);
+  // Return true if the new decomposition is invalid.
+  [[nodiscard]] bool add(int64_t OtherOffset) {
+    return AddOverflow(Offset, OtherOffset, Offset);
   }
 
-  void add(const Decomposition &Other) {
-    add(Other.Offset);
+  // Return true if the new decomposition is invalid.
+  [[nodiscard]] bool add(const Decomposition &Other) {
+    if (add(Other.Offset))
+      return true;
     append_range(Vars, Other.Vars);
+    return false;
   }
 
-  void sub(const Decomposition &Other) {
+  // Return true if the new decomposition is invalid.
+  [[nodiscard]] bool sub(const Decomposition &Other) {
     Decomposition Tmp = Other;
-    Tmp.mul(-1);
-    add(Tmp.Offset);
+    if (Tmp.mul(-1))
+      return true;
+    if (add(Tmp.Offset))
+      return true;
     append_range(Vars, Tmp.Vars);
+    return false;
   }
 
-  void mul(int64_t Factor) {
-    Offset = multiplyWithOverflow(Offset, Factor);
+  // Return true if the new decomposition is invalid.
+  [[nodiscard]] bool mul(int64_t Factor) {
+    if (MulOverflow(Offset, Factor, Offset))
+      return true;
     for (auto &Var : Vars)
-      Var.Coefficient = multiplyWithOverflow(Var.Coefficient, Factor);
+      if (MulOverflow(Var.Coefficient, Factor, Var.Coefficient))
+        return true;
+    return false;
   }
 };
 
@@ -467,8 +465,10 @@ static Decomposition decomposeGEP(GEPOperator &GEP,
   Decomposition Result(ConstantOffset.getSExtValue(), DecompEntry(1, BasePtr));
   for (auto [Index, Scale] : VariableOffsets) {
     auto IdxResult = decompose(Index, Preconditions, IsSigned, DL);
-    IdxResult.mul(Scale.getSExtValue());
-    Result.add(IdxResult);
+    if (IdxResult.mul(Scale.getSExtValue()))
+      return &GEP;
+    if (Result.add(IdxResult))
+      return &GEP;
 
     if (!NW.hasNoUnsignedWrap()) {
       // Try to prove nuw from nusw and nneg.
@@ -488,11 +488,13 @@ static Decomposition decompose(Value *V,
                                SmallVectorImpl<ConditionTy> &Preconditions,
                                bool IsSigned, const DataLayout &DL) {
 
-  auto MergeResults = [&Preconditions, IsSigned, &DL](Value *A, Value *B,
-                                                      bool IsSignedB) {
+  auto MergeResults = [&Preconditions, IsSigned,
+                       &DL](Value *A, Value *B,
+                            bool IsSignedB) -> std::optional<Decomposition> {
     auto ResA = decompose(A, Preconditions, IsSigned, DL);
     auto ResB = decompose(B, Preconditions, IsSignedB, DL);
-    ResA.add(ResB);
+    if (ResA.add(ResB))
+      return std::nullopt;
     return ResA;
   };
 
@@ -534,20 +536,21 @@ static Decomposition decompose(Value *V,
     }
 
     if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1))))
-      return MergeResults(Op0, Op1, IsSigned);
+      if (auto Decomp = MergeResults(Op0, Op1, IsSigned))
+        return *Decomp;
 
     if (match(V, m_NSWSub(m_Value(Op0), m_Value(Op1)))) {
       auto ResA = decompose(Op0, Preconditions, IsSigned, DL);
       auto ResB = decompose(Op1, Preconditions, IsSigned, DL);
-      ResA.sub(ResB);
-      return ResA;
+      if (!ResA.sub(ResB))
+        return ResA;
     }
 
     ConstantInt *CI;
     if (match(V, m_NSWMul(m_Value(Op0), m_ConstantInt(CI))) && canUseSExt(CI)) {
       auto Result = decompose(Op0, Preconditions, IsSigned, DL);
-      Result.mul(CI->getSExtValue());
-      return Result;
+      if (!Result.mul(CI->getSExtValue()))
+        return Result;
     }
 
     // (shl nsw x, shift) is (mul nsw x, (1<<shift)), with the exception of
@@ -557,8 +560,8 @@ static Decomposition decompose(Value *V,
       if (Shift < Ty->getIntegerBitWidth() - 1) {
         assert(Shift < 64 && "Would overflow");
         auto Result = decompose(Op0, Preconditions, IsSigned, DL);
-        Result.mul(int64_t(1) << Shift);
-        return Result;
+        if (!Result.mul(int64_t(1) << Shift))
+          return Result;
       }
     }
 
@@ -592,9 +595,9 @@ static Decomposition decompose(Value *V,
 
   Value *Op1;
   ConstantInt *CI;
-  if (match(V, m_NUWAdd(m_Value(Op0), m_Value(Op1)))) {
-    return MergeResults(Op0, Op1, IsSigned);
-  }
+  if (match(V, m_NUWAdd(m_Value(Op0), m_Value(Op1))))
+    if (auto Decomp = MergeResults(Op0, Op1, IsSigned))
+      return *Decomp;
   if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) {
     if (!isKnownNonNegative(Op0, DL))
       Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0,
@@ -603,7 +606,8 @@ static Decomposition decompose(Value *V,
       Preconditions.emplace_back(CmpInst::ICMP_SGE, Op1,
                                  ConstantInt::get(Op1->getType(), 0));
 
-    return MergeResults(Op0, Op1, IsSigned);
+    if (auto Decomp = MergeResults(Op0, Op1, IsSigned))
+      return *Decomp;
   }
 
   if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() &&
@@ -611,33 +615,35 @@ static Decomposition decompose(Value *V,
     Preconditions.emplace_back(
         CmpInst::ICMP_UGE, Op0,
         ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1));
-    return MergeResults(Op0, CI, true);
+    if (auto Decomp = MergeResults(Op0, CI, true))
+      return *Decomp;
   }
 
   // Decompose or as an add if there are no common bits between the operands.
   if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI))))
-    return MergeResults(Op0, CI, IsSigned);
+    if (auto Decomp = MergeResults(Op0, CI, IsSigned))
+      return *Decomp;
 
   if (match(V, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI)) {
     if (CI->getSExtValue() < 0 || CI->getSExtValue() >= 64)
       return {V, IsKnownNonNegative};
     auto Result = decompose(Op1, Preconditions, IsSigned, DL);
-    Result.mul(int64_t{1} << CI->getSExtValue());
-    return Result;
+    if (!Result.mul(int64_t{1} << CI->getSExtValue()))
+      return Result;
   }
 
   if (match(V, m_NUWMul(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI) &&
       (!CI->isNegative())) {
     auto Result = decompose(Op1, Preconditions, IsSigned, DL);
-    Result.mul(CI->getSExtValue());
-    return Result;
+    if (!Result.mul(CI->getSExtValue()))
+      return Result;
   }
 
   if (match(V, m_NUWSub(m_Value(Op0), m_Value(Op1)))) {
     auto ResA = decompose(Op0, Preconditions, IsSigned, DL);
     auto ResB = decompose(Op1, Preconditions, IsSigned, DL);
-    ResA.sub(ResB);
-    return ResA;
+    if (!ResA.sub(ResB))
+      return ResA;
   }
 
   return {V, IsKnownNonNegative};
diff --git a/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll b/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll
index 5dc9ade756d49..f36ac311878b2 100644
--- a/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll
+++ b/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll
@@ -62,7 +62,8 @@ define i1 @pr140481(i32 %x) {
 ; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[X]], 5001000
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i32 [[ADD]], -5001000
 ; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i32 [[MUL1]], 5001000
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[MUL2]], 0
+; CHECK-NEXT:    ret i1 [[CMP2]]
 ;
 entry:
   %cond = icmp slt i32 %x, 0



More information about the llvm-commits mailing list