[llvm] [PredicateInfo] Infer operand bound from mul nuw square predicates (PR #173127)

Ken Matsui via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 20 21:08:29 PST 2025


https://github.com/ken-matsui updated https://github.com/llvm/llvm-project/pull/173127

>From 00063b7e34b4602fcced54df62d70092d3d82537 Mon Sep 17 00:00:00 2001
From: Ken Matsui <github at kmts.me>
Date: Wed, 17 Dec 2025 19:21:16 -0500
Subject: [PATCH 1/2] Add baseline tests for upcoming patch

---
 .../mul-nuw-square.ll                         | 91 +++++++++++++++++++
 1 file changed, 91 insertions(+)
 create mode 100644 llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll

diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
new file mode 100644
index 0000000000000..f1bda61b42cc6
--- /dev/null
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
@@ -0,0 +1,91 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=correlated-propagation -S < %s | FileCheck %s
+
+declare void @llvm.assume(i1)
+
+define i1 @assume_mul_nuw_square_i8(i8 %s) {
+; CHECK-LABEL: @assume_mul_nuw_square_i8(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]]
+; CHECK-NEXT:    [[COND:%.*]] = icmp ule i8 [[MUL]], 120
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %mul = mul nuw i8 %s, %s
+  %cond = icmp ule i8 %mul, 120
+  call void @llvm.assume(i1 %cond)
+  %cmp = icmp ult i8 %s, 16 ; nuw: s*s <=u 255, so s <u 16
+  ret i1 %cmp
+}
+
+define i1 @assume_mul_nuw_square_i5(i5 %s) {
+; CHECK-LABEL: @assume_mul_nuw_square_i5(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i5 [[S:%.*]], [[S]]
+; CHECK-NEXT:    [[COND:%.*]] = icmp ult i5 [[MUL]], 15
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i5 [[S]], 8
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %mul = mul nuw i5 %s, %s
+  %cond = icmp ult i5 %mul, 15
+  call void @llvm.assume(i1 %cond)
+  %cmp = icmp ult i5 %s, 8 ; odd width: nuw gives s*s <=u 31, so s <=u 5 (<u 8)
+  ret i1 %cmp
+}
+
+define i1 @branch_mul_nuw_square(i8 %s, i8 %num) {
+; CHECK-LABEL: @branch_mul_nuw_square(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]]
+; CHECK-NEXT:    [[COND:%.*]] = icmp ule i8 [[MUL]], [[NUM:%.*]]
+; CHECK-NEXT:    br i1 [[COND]], label [[TRUE:%.*]], label [[FALSE:%.*]]
+; CHECK:       true:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK:       false:
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %mul = mul nuw i8 %s, %s
+  %cond = icmp ule i8 %mul, %num
+  br i1 %cond, label %true, label %false
+
+true:
+  %cmp = icmp ult i8 %s, 16 ; br on poison is UB, so nuw holds: s <u 16
+  ret i1 %cmp
+
+false:
+  %cmp2 = icmp ult i8 %s, 16 ; same bound on the false edge
+  ret i1 %cmp2
+}
+
+; negative test: without nuw, s*s may wrap (s = 16 gives 0), so s is unbounded.
+define i1 @assume_mul_square_no_nuw(i8 %s) {
+; CHECK-LABEL: @assume_mul_square_no_nuw(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[S:%.*]], [[S]]
+; CHECK-NEXT:    [[COND:%.*]] = icmp ule i8 [[MUL]], 120
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %mul = mul i8 %s, %s
+  %cond = icmp ule i8 %mul, 120
+  call void @llvm.assume(i1 %cond)
+  %cmp = icmp ult i8 %s, 16
+  ret i1 %cmp
+}
+
+; negative test: a non-wrapping s*t does not bound s alone (s = 120, t = 1).
+define i1 @assume_mul_nuw_not_square(i8 %s, i8 %t) {
+; CHECK-LABEL: @assume_mul_nuw_not_square(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[T:%.*]]
+; CHECK-NEXT:    [[COND:%.*]] = icmp ule i8 [[MUL]], 120
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %mul = mul nuw i8 %s, %t
+  %cond = icmp ule i8 %mul, 120
+  call void @llvm.assume(i1 %cond)
+  %cmp = icmp ult i8 %s, 16
+  ret i1 %cmp
+}

>From b11a07e0a3b5b66e22869f80b3226201b513471b Mon Sep 17 00:00:00 2001
From: Ken Matsui <github at kmts.me>
Date: Sun, 21 Dec 2025 00:07:45 -0500
Subject: [PATCH 2/2] [ValueTracking] Infer operand bound from mul nuw square
 predicates

If `mul nuw X, X` feeds an icmp that is used as an assume or branch
condition, the multiply cannot overflow (otherwise the condition would
be poison, which is UB for assumes and control flow). Hence:

  X <u 2^ceil(bitwidth(X)/2) (e.g., i16: X <u 256).
---
 llvm/include/llvm/Analysis/ValueTracking.h    |  5 +++
 llvm/lib/Analysis/LazyValueInfo.cpp           |  3 ++
 llvm/lib/Analysis/ValueTracking.cpp           | 40 ++++++++++++++++++-
 .../mul-nuw-square.ll                         | 12 ++----
 4 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index b730a36488780..0c87fa2522b4b 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -663,6 +663,11 @@ LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
 /// based on the vscale_range function attribute.
 LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth);
 
+/// If icmp \p LHS, \p RHS can't be poison and one side is `mul nuw V, V`,
+/// return the implied unsigned range for \p V: [0, 2^ceil(bitwidth(V)/2)).
+LLVM_ABI std::optional<ConstantRange>
+getRangeForNuwMulSquare(const Value *V, const Value *LHS, const Value *RHS);
+
 /// Determine the possible constant range of an integer or vector of integer
 /// value. This is intended as a cheap, non-recursive check.
 LLVM_ABI ConstantRange computeConstantRange(const Value *V, bool ForSigned,
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index df75999eb6080..462459bf56f6a 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -1353,6 +1353,9 @@ std::optional<ValueLatticeElement> LazyValueInfoImpl::getValueFromICmpCondition(
     return ValueLatticeElement::getOverdefined();
 
   unsigned BitWidth = Ty->getScalarSizeInBits();
+  if (auto Range = getRangeForNuwMulSquare(Val, LHS, RHS))
+    return ValueLatticeElement::getRange(*Range);
+
   APInt Offset(BitWidth, 0);
   if (matchICmpOperand(Offset, LHS, Val, EdgePred))
     return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 045cbab221ac3..93d8f9864a45c 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -119,6 +119,24 @@ static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
   return nullptr;
 }
 
+std::optional<ConstantRange> llvm::getRangeForNuwMulSquare(
+    const Value *V, const Value *LHS, const Value *RHS) {
+  // Callers must know the compare of LHS/RHS is not poison (assume/branch);
+  // then `mul nuw V, V` did not wrap, so V * V <u 2^BW and V <u 2^ceil(BW/2).
+  Type *Ty = V->getType();
+  auto IsNuwSquareOfV = [V](const Value *Op) {
+    return match(Op, m_NUWMul(m_Specific(V), m_Specific(V)));
+  };
+  if (!Ty->isIntegerTy() || (!IsNuwSquareOfV(LHS) && !IsNuwSquareOfV(RHS)))
+    return std::nullopt;
+  unsigned BW = Ty->getScalarSizeInBits();
+  unsigned HalfBW = BW - BW / 2; // ceil(BW / 2)
+  if (HalfBW >= BW)
+    return std::nullopt; // Only i1 lands here: the bound would span the type.
+  return ConstantRange::getNonEmpty(APInt::getZero(BW),
+                                    APInt::getOneBitSet(BW, HalfBW));
+}
+
 static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
                                    const APInt &DemandedElts,
                                    APInt &DemandedLHS, APInt &DemandedRHS) {
@@ -976,6 +994,9 @@ static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
   Value *LHS = Cmp->getOperand(0);
   Value *RHS = Cmp->getOperand(1);
 
+  if (auto Range = getRangeForNuwMulSquare(V, LHS, RHS))
+    Known = Known.unionWith(Range->toKnownBits());
+
   // Handle icmp pred (trunc V), C
   if (match(LHS, m_Trunc(m_Specific(V)))) {
     KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
@@ -10383,7 +10404,16 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
       Value *Arg = I->getArgOperand(0);
       ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
       // Currently we just use information from comparisons.
-      if (!Cmp || Cmp->getOperand(0) != V)
+      if (!Cmp)
+        continue;
+
+      if (auto Range = getRangeForNuwMulSquare(V, Cmp->getOperand(0),
+                                               Cmp->getOperand(1))) {
+        CR = CR.intersectWith(*Range);
+        continue;
+      }
+
+      if (Cmp->getOperand(0) != V)
         continue;
       // TODO: Set "ForSigned" parameter via Cmp->isSigned()?
       ConstantRange RHS =
@@ -10514,6 +10544,14 @@ void llvm::findValuesAffectedByCondition(
         }
       }
 
+      auto AddNuwSquareOperand = [&AddAffected](Value *Op) {
+        Value *SquareOp = nullptr;
+        if (match(Op, m_NUWMul(m_Value(SquareOp), m_Deferred(SquareOp))))
+          AddAffected(SquareOp);
+      };
+      AddNuwSquareOperand(A);
+      AddNuwSquareOperand(B);
+
       if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
         AddAffected(X);
     } else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
index f1bda61b42cc6..afec6387d7301 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
@@ -8,8 +8,7 @@ define i1 @assume_mul_nuw_square_i8(i8 %s) {
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]]
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ule i8 [[MUL]], 120
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[S]], 16
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %mul = mul nuw i8 %s, %s
   %cond = icmp ule i8 %mul, 120
@@ -23,8 +22,7 @@ define i1 @assume_mul_nuw_square_i5(i5 %s) {
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i5 [[S:%.*]], [[S]]
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i5 [[MUL]], 15
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i5 [[S]], 8
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %mul = mul nuw i5 %s, %s
   %cond = icmp ult i5 %mul, 15
@@ -39,11 +37,9 @@ define i1 @branch_mul_nuw_square(i8 %s, i8 %num) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ule i8 [[MUL]], [[NUM:%.*]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[TRUE:%.*]], label [[FALSE:%.*]]
 ; CHECK:       true:
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[S]], 16
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       false:
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp ult i8 [[S]], 16
-; CHECK-NEXT:    ret i1 [[CMP2]]
+; CHECK-NEXT:    ret i1 true
 ;
   %mul = mul nuw i8 %s, %s
   %cond = icmp ule i8 %mul, %num



More information about the llvm-commits mailing list