[llvm] [PredicateInfo] Infer operand bound from mul nuw square predicates (PR #173127)
Ken Matsui via llvm-commits
llvm-commits at lists.llvm.org
Sat Dec 20 21:08:29 PST 2025
https://github.com/ken-matsui updated https://github.com/llvm/llvm-project/pull/173127
>From 00063b7e34b4602fcced54df62d70092d3d82537 Mon Sep 17 00:00:00 2001
From: Ken Matsui <github at kmts.me>
Date: Wed, 17 Dec 2025 19:21:16 -0500
Subject: [PATCH 1/2] Add baseline tests for upcoming patch
---
.../mul-nuw-square.ll | 91 +++++++++++++++++++
1 file changed, 91 insertions(+)
create mode 100644 llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
new file mode 100644
index 0000000000000..f1bda61b42cc6
--- /dev/null
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
@@ -0,0 +1,91 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=correlated-propagation -S < %s | FileCheck %s
+
+declare void @llvm.assume(i1)
+
+define i1 @assume_mul_nuw_square_i8(i8 %s) {
+; CHECK-LABEL: @assume_mul_nuw_square_i8(
+; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]]
+; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120
+; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %mul = mul nuw i8 %s, %s
+ %cond = icmp ule i8 %mul, 120
+ call void @llvm.assume(i1 %cond)
+ %cmp = icmp ult i8 %s, 16
+ ret i1 %cmp
+}
+
+define i1 @assume_mul_nuw_square_i5(i5 %s) {
+; CHECK-LABEL: @assume_mul_nuw_square_i5(
+; CHECK-NEXT: [[MUL:%.*]] = mul nuw i5 [[S:%.*]], [[S]]
+; CHECK-NEXT: [[COND:%.*]] = icmp ult i5 [[MUL]], 15
+; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i5 [[S]], 8
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %mul = mul nuw i5 %s, %s
+ %cond = icmp ult i5 %mul, 15
+ call void @llvm.assume(i1 %cond)
+ %cmp = icmp ult i5 %s, 8
+ ret i1 %cmp
+}
+
+define i1 @branch_mul_nuw_square(i8 %s, i8 %num) {
+; CHECK-LABEL: @branch_mul_nuw_square(
+; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]]
+; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], [[NUM:%.*]]
+; CHECK-NEXT: br i1 [[COND]], label [[TRUE:%.*]], label [[FALSE:%.*]]
+; CHECK: true:
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK: false:
+; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT: ret i1 [[CMP2]]
+;
+ %mul = mul nuw i8 %s, %s
+ %cond = icmp ule i8 %mul, %num
+ br i1 %cond, label %true, label %false
+
+true:
+ %cmp = icmp ult i8 %s, 16
+ ret i1 %cmp
+
+false:
+ %cmp2 = icmp ult i8 %s, 16
+ ret i1 %cmp2
+}
+
+; negative test: missing nuw on the multiply.
+define i1 @assume_mul_square_no_nuw(i8 %s) {
+; CHECK-LABEL: @assume_mul_square_no_nuw(
+; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[S:%.*]], [[S]]
+; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120
+; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %mul = mul i8 %s, %s
+ %cond = icmp ule i8 %mul, 120
+ call void @llvm.assume(i1 %cond)
+ %cmp = icmp ult i8 %s, 16
+ ret i1 %cmp
+}
+
+; negative test: multiply is not a square.
+define i1 @assume_mul_nuw_not_square(i8 %s, i8 %t) {
+; CHECK-LABEL: @assume_mul_nuw_not_square(
+; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[T:%.*]]
+; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120
+; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %mul = mul nuw i8 %s, %t
+ %cond = icmp ule i8 %mul, 120
+ call void @llvm.assume(i1 %cond)
+ %cmp = icmp ult i8 %s, 16
+ ret i1 %cmp
+}
>From b11a07e0a3b5b66e22869f80b3226201b513471b Mon Sep 17 00:00:00 2001
From: Ken Matsui <github at kmts.me>
Date: Sun, 21 Dec 2025 00:07:45 -0500
Subject: [PATCH 2/2] [ValueTracking] Infer operand bound from mul nuw square predicates
When a `mul nuw X, X` feeds an assume or branch condition, the multiply
cannot overflow: if it did, the condition would be poison, which is UB
for both assumes and control flow. The absence of unsigned overflow implies:
X < 2^ceil(bitwidth(X)/2) (e.g., i16: X < 256).
---
llvm/include/llvm/Analysis/ValueTracking.h | 5 +++
llvm/lib/Analysis/LazyValueInfo.cpp | 3 ++
llvm/lib/Analysis/ValueTracking.cpp | 40 ++++++++++++++++++-
.../mul-nuw-square.ll | 12 ++----
4 files changed, 51 insertions(+), 9 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index b730a36488780..0c87fa2522b4b 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -663,6 +663,11 @@ LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
/// based on the vscale_range function attribute.
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth);
+/// If \p LHS or \p RHS is `mul nuw V, V`, return the implied unsigned range for
+/// \p V: [0, 2^ceil(bitwidth(V)/2)).
+LLVM_ABI std::optional<ConstantRange>
+getRangeForNuwMulSquare(const Value *V, const Value *LHS, const Value *RHS);
+
/// Determine the possible constant range of an integer or vector of integer
/// value. This is intended as a cheap, non-recursive check.
LLVM_ABI ConstantRange computeConstantRange(const Value *V, bool ForSigned,
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index df75999eb6080..462459bf56f6a 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -1353,6 +1353,9 @@ std::optional<ValueLatticeElement> LazyValueInfoImpl::getValueFromICmpCondition(
return ValueLatticeElement::getOverdefined();
unsigned BitWidth = Ty->getScalarSizeInBits();
+ if (auto Range = getRangeForNuwMulSquare(Val, LHS, RHS))
+ return ValueLatticeElement::getRange(*Range);
+
APInt Offset(BitWidth, 0);
if (matchICmpOperand(Offset, LHS, Val, EdgePred))
return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 045cbab221ac3..93d8f9864a45c 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -119,6 +119,24 @@ static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
return nullptr;
}
+std::optional<ConstantRange> llvm::getRangeForNuwMulSquare(
+ const Value *V, const Value *LHS, const Value *RHS) {
+ if (!V->getType()->isIntegerTy())
+ return std::nullopt;
+
+ if (!match(LHS, m_NUWMul(m_Specific(V), m_Specific(V))) &&
+ !match(RHS, m_NUWMul(m_Specific(V), m_Specific(V))))
+ return std::nullopt;
+
+ unsigned BitWidth = V->getType()->getScalarSizeInBits();
+ unsigned LimitBits = (BitWidth + 1) / 2;
+ if (LimitBits >= BitWidth)
+ return std::nullopt;
+
+ APInt Upper = APInt::getOneBitSet(BitWidth, LimitBits);
+ return ConstantRange::getNonEmpty(APInt::getZero(BitWidth), Upper);
+}
+
static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
const APInt &DemandedElts,
APInt &DemandedLHS, APInt &DemandedRHS) {
@@ -976,6 +994,9 @@ static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
Value *LHS = Cmp->getOperand(0);
Value *RHS = Cmp->getOperand(1);
+ if (auto Range = getRangeForNuwMulSquare(V, LHS, RHS))
+ Known = Known.unionWith(Range->toKnownBits());
+
// Handle icmp pred (trunc V), C
if (match(LHS, m_Trunc(m_Specific(V)))) {
KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
@@ -10383,7 +10404,16 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
Value *Arg = I->getArgOperand(0);
ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
// Currently we just use information from comparisons.
- if (!Cmp || Cmp->getOperand(0) != V)
+ if (!Cmp)
+ continue;
+
+ if (auto Range = getRangeForNuwMulSquare(V, Cmp->getOperand(0),
+ Cmp->getOperand(1))) {
+ CR = CR.intersectWith(*Range);
+ continue;
+ }
+
+ if (Cmp->getOperand(0) != V)
continue;
// TODO: Set "ForSigned" parameter via Cmp->isSigned()?
ConstantRange RHS =
@@ -10514,6 +10544,14 @@ void llvm::findValuesAffectedByCondition(
}
}
+ auto AddNuwSquareOperand = [&AddAffected](Value *Op) {
+ Value *SquareOp = nullptr;
+ if (match(Op, m_NUWMul(m_Value(SquareOp), m_Deferred(SquareOp))))
+ AddAffected(SquareOp);
+ };
+ AddNuwSquareOperand(A);
+ AddNuwSquareOperand(B);
+
if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
AddAffected(X);
} else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
index f1bda61b42cc6..afec6387d7301 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
@@ -8,8 +8,7 @@ define i1 @assume_mul_nuw_square_i8(i8 %s) {
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]]
; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120
; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%mul = mul nuw i8 %s, %s
%cond = icmp ule i8 %mul, 120
@@ -23,8 +22,7 @@ define i1 @assume_mul_nuw_square_i5(i5 %s) {
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i5 [[S:%.*]], [[S]]
; CHECK-NEXT: [[COND:%.*]] = icmp ult i5 [[MUL]], 15
; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i5 [[S]], 8
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%mul = mul nuw i5 %s, %s
%cond = icmp ult i5 %mul, 15
@@ -39,11 +37,9 @@ define i1 @branch_mul_nuw_square(i8 %s, i8 %num) {
; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], [[NUM:%.*]]
; CHECK-NEXT: br i1 [[COND]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: true:
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
; CHECK: false:
-; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i8 [[S]], 16
-; CHECK-NEXT: ret i1 [[CMP2]]
+; CHECK-NEXT: ret i1 true
;
%mul = mul nuw i8 %s, %s
%cond = icmp ule i8 %mul, %num
More information about the llvm-commits mailing list