[llvm] [KnownBits] Refine known bits for lerp (PR #166378)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 4 06:46:41 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-transforms
Author: Valeriy Savchenko (SavchenkoValeriy)
Changes:
In this patch, we try to detect the lerp pattern: a * (b - c) + c * d
where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
In that particular case, we can use the following chain of reasoning:
a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
Since that is true for arbitrary a, b, c and d within our constraints, we can
conclude that:
max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
Since any result of the lerp is less than or equal to U, it has at least as many
leading zeros as U.
While this is quite a specific pattern, it is fairly common in computer
graphics in the form of alpha blending.
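For a concrete instance, take 8-bit alpha blending with b = 255: the bound is U = 255 * 255 = 65025 < 2^16, so a 32-bit result is known to have at least 16 leading zero bits. Below is a minimal standalone C++ sketch (not part of the patch; `blend` is just an illustrative helper) that exhaustively checks this bound:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative helper (hypothetical, not from the patch): the 8-bit
// alpha-blend instance of a * (b - c) + c * d, with b = 255 and c = Alpha.
static uint32_t blend(uint32_t Src, uint32_t Alpha, uint32_t Dst) {
  return Src * (255u - Alpha) + Alpha * Dst;
}

int main() {
  // U = max(max(a), max(d)) * max(b) = 255 * 255 = 65025, which fits in
  // 16 bits, i.e. a 32-bit result has at least 16 leading zero bits.
  constexpr uint32_t U = 255u * 255u;
  static_assert(U < (1u << 16), "bound fits in 16 bits");

  // Exhaustively verify the bound over all 8-bit operand combinations.
  for (uint32_t Src = 0; Src <= 255; ++Src)
    for (uint32_t Alpha = 0; Alpha <= 255; ++Alpha)
      for (uint32_t Dst = 0; Dst <= 255; ++Dst)
        assert(blend(Src, Alpha, Dst) <= U);
  return 0;
}
```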
In conjunction with #165877, this increases the vectorization factor for lerp loops.
---
Full diff: https://github.com/llvm/llvm-project/pull/166378.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+143)
- (added) llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll (+156)
``````````diff
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 0a72076f51824..4c74710065371 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -350,6 +350,140 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
return V->getType()->getScalarSizeInBits() - SignBits + 1;
}
+// Try to detect the lerp pattern: a * (b - c) + c * d
+// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
+//
+// In that particular case, we can use the following chain of reasoning:
+//
+// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
+//
+// Since that is true for arbitrary a, b, c and d within our constraints, we can
+// conclude that:
+//
+// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
+//
+// Since any result of the lerp is less than or equal to U, it has at least
+// as many leading zeros as U.
+//
+// While this is quite a specific pattern, it is fairly common in computer
+// graphics in the form of alpha blending.
+//
+// Returns unknown bits if the pattern doesn't match or constraints don't apply
+// to the given operands.
+static KnownBits computeKnownBitsFromLerpPattern(const Value *Op0,
+ const Value *Op1,
+ const APInt &DemandedElts,
+ const SimplifyQuery &Q,
+ unsigned Depth) {
+
+ Type *Ty = Op0->getType();
+ const unsigned BitWidth = Ty->getScalarSizeInBits();
+
+ KnownBits Result(BitWidth);
+
+ // Only handle scalar types for now.
+ if (Ty->isVectorTy())
+ return Result;
+
+ // Try to match: a * (b - c) + c * d.
+ // When a == 1, A stays nullptr; the same applies to d/D as well.
+ const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+
+ const auto MatchSubBC = [&]() {
+ // (b - c) can have two forms that interest us:
+ //
+ // 1. sub nuw %b, %c
+ // 2. xor %c, %b
+ //
+ // For the first case, the nuw flag guarantees our requirement b >= c.
+ //
+ // The second case happens when the analysis can infer that b is a mask for
+ // c, so the sub can be transformed into an xor (usually the case for
+ // constant b). Even though xor is commutative, canonicalization ensures
+ // that the constant ends up on the RHS. The xor of two non-negative
+ // integers is guaranteed to be non-negative as well.
+ return m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
+ m_Xor(m_Value(C), m_Value(B)));
+ };
+
+ const auto MatchASubBC = [&]() {
+ // Cases:
+ // - a * (b - c)
+ // - (b - c) * a
+ // - (b - c) <- a implicitly equals 1
+ return m_CombineOr(m_CombineOr(m_Mul(m_Value(A), MatchSubBC()),
+ m_Mul(MatchSubBC(), m_Value(A))),
+ MatchSubBC());
+ };
+
+ const auto MatchCD = [&]() {
+ // Cases:
+ // - d * c
+ // - c * d
+ // - c <- d implicitly equals 1
+ return m_CombineOr(m_CombineOr(m_Mul(m_Value(D), m_Specific(C)),
+ m_Mul(m_Specific(C), m_Value(D))),
+ m_Specific(C));
+ };
+
+ const auto Match = [&](const Value *LHS, const Value *RHS) {
+ // MatchCD uses m_Specific(C), so C has to be bound before it runs;
+ // therefore match(LHS, MatchASubBC()) must be evaluated first and
+ // succeed.
+ //
+ // If Match returns true, it is guaranteed that B != nullptr and C != nullptr.
+ return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
+ };
+
+ if (!Match(Op0, Op1) && !Match(Op1, Op0))
+ return Result;
+
+ const auto ComputeKnownBitsOrOne = [&](const Value *V) {
+ // For some of the values (A and D) we use the convention of leaving
+ // them nullptr to signify an implicit constant 1.
+ return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
+ : KnownBits::makeConstant(APInt(BitWidth, 1));
+ };
+
+ // Check that all operands are non-negative
+ const KnownBits KnownA = ComputeKnownBitsOrOne(A);
+ if (!KnownA.isNonNegative())
+ return Result;
+
+ const KnownBits KnownD = ComputeKnownBitsOrOne(D);
+ if (!KnownD.isNonNegative())
+ return Result;
+
+ const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
+ if (!KnownB.isNonNegative())
+ return Result;
+
+ const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
+ if (!KnownC.isNonNegative())
+ return Result;
+
+ // Compute max(a, d)
+ const APInt MaxA = KnownA.getMaxValue();
+ const APInt MaxD = KnownD.getMaxValue();
+ const APInt MaxAD = MaxA.ult(MaxD) ? MaxD : MaxA;
+
+ // Compute max(a, d) * max(b)
+ const APInt MaxB = KnownB.getMaxValue();
+ bool Overflow;
+ const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);
+
+ if (Overflow)
+ return Result;
+
+ // Count leading zeros in upper bound
+ const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
+
+ // Create KnownBits with only leading zeros set
+ Result.Zero.setHighBits(MinimumNumberOfLeadingZeros);
+
+ return Result;
+}
+
static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
bool NSW, bool NUW,
const APInt &DemandedElts,
@@ -369,6 +503,15 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
.value_or(false))
KnownOut.makeNonNegative();
+
+ if (Add) {
+ // Try to match lerp pattern and combine results
+ const KnownBits LerpKnown =
+ computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, Q, Depth);
+ // Union of any two conservative estimates results in a conservative
+ // estimate that is at least as precise as each individual estimate.
+ KnownOut = KnownOut.unionWith(LerpKnown);
+ }
}
static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
diff --git a/llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll b/llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll
new file mode 100644
index 0000000000000..3018d3e99f636
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll
@@ -0,0 +1,156 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Test known bits refinement for the pattern: a * (b - c) + c * d
+; where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
+; This pattern is a generalization of lerp and appears frequently in graphics operations.
+
+define i32 @test_clamp(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i32 @test_clamp(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: ret i32 [[ADD]]
+;
+ %a32 = zext i8 %a to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = sub i32 255, %c32
+ %mul1 = mul i32 %a32, %sub
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %mul1, %mul2
+ %cmp = icmp ugt i32 %add, 65535
+ %result = select i1 %cmp, i32 65535, i32 %add
+ ret i32 %result
+}
+
+define i1 @test_trunc_cmp(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %a32 = zext i8 %a to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = sub i32 255, %c32
+ %mul1 = mul i32 %a32, %sub
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %mul1, %mul2
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_xor(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_xor(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %a32 = zext i8 %a to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = xor i32 255, %c32
+ %mul1 = mul i32 %a32, %sub
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %mul1, %mul2
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_arbitrary_b(i8 %a, i8 %b, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_arbitrary_b(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %a32 = zext i8 %a to i32
+ %b32 = zext i8 %b to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = sub nsw nuw i32 %b32, %c32
+ %mul1 = mul i32 %a32, %sub
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %mul1, %mul2
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+
+define i1 @test_trunc_cmp_no_a(i8 %b, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_no_a(
+; CHECK-SAME: i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[MUL1:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %b32 = zext i8 %b to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = sub nuw i32 %b32, %c32
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %sub, %mul2
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_no_d(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_no_d(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[C32]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %a32 = zext i8 %a to i32
+ %b32 = zext i8 %b to i32
+ %c32 = zext i8 %c to i32
+ %sub = sub nsw nuw i32 %b32, %c32
+ %mul1 = mul i32 %a32, %sub
+ %add = add i32 %mul1, %c32
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+declare void @llvm.assume(i1)
``````````
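For reference, the bound computation can also be reproduced at the KnownBits level outside of ValueTracking. The snippet below is a small standalone sketch (not part of the patch; it only uses LLVM's public APInt and KnownBits APIs and needs to be linked against LLVMSupport) that mirrors the patch's arithmetic for the 8-bit case, where a and d come from `zext i8` and b is the constant 255:

```cpp
#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  const unsigned BitWidth = 32;

  // a and d are zero-extended i8 values: the top 24 bits are known zero.
  KnownBits KnownA(BitWidth), KnownD(BitWidth);
  KnownA.Zero.setHighBits(24);
  KnownD.Zero.setHighBits(24);
  // b is the constant 255 (the sub/xor mask in the alpha-blend case).
  const KnownBits KnownB = KnownBits::makeConstant(APInt(BitWidth, 255));

  // max(a, d), as in the patch.
  const APInt MaxA = KnownA.getMaxValue();
  const APInt MaxD = KnownD.getMaxValue();
  const APInt MaxAD = MaxA.ult(MaxD) ? MaxD : MaxA;

  // max(a, d) * max(b) = 255 * 255 = 65025; no overflow in 32 bits.
  bool Overflow = false;
  const APInt UpperBound = MaxAD.umul_ov(KnownB.getMaxValue(), Overflow);

  // Prints 16: the result of the lerp has at least 16 leading zero bits.
  outs() << UpperBound.countl_zero() << "\n";
  return 0;
}
```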
https://github.com/llvm/llvm-project/pull/166378