[llvm] [ValueTracking] Refine known bits for linear interpolation patterns (PR #166378)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 6 04:30:04 PST 2025


================
@@ -350,6 +350,152 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
   return V->getType()->getScalarSizeInBits() - SignBits + 1;
 }
 
+/// Try to detect the lerp pattern: a * (b - c) + c * d
+/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
+///
+/// In that particular case, we can use the following chain of reasoning:
+///
+///   a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
+///
+/// Since that is true for arbitrary a, b, c and d within our constraints, we
+/// can conclude that:
+///
+///   max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
+///
+/// Since every result of the lerp is less than or equal to U, it has at
+/// least as many leading zeros as U.
+///
+/// Although this is a rather specific pattern, it is fairly common in computer
+/// graphics in the form of alpha blending.
+///
+/// Modifies the given KnownOut in place with the inferred information.
+static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
+                                            const APInt &DemandedElts,
+                                            KnownBits &KnownOut,
+                                            const SimplifyQuery &Q,
+                                            unsigned Depth) {
+
+  Type *Ty = Op0->getType();
+  const unsigned BitWidth = Ty->getScalarSizeInBits();
+
+  // Only handle scalar types for now.
+  if (Ty->isVectorTy())
+    return;
+
+  // Try to match: a * (b - c) + c * d.
+  // When a is implicitly 1, A is left as nullptr; the same applies to d/D.
+  const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+  const BinaryOperator *SubBC = nullptr;
+
+  const auto MatchSubBC = [&]() {
+    // (b - c) can have two forms that interest us:
+    //
+    //   1. sub nuw %b, %c
+    //   2. xor %c, %b
+    //
+    // For the first case, the nuw flag guarantees our requirement b >= c.
+    //
+    // The second case can occur when the analysis infers that b is a mask
+    // for c, so the sub can be turned into an xor (this usually happens when
+    // b is a constant). Even though xor is commutative, canonicalization
+    // ensures that the constant is the RHS. Additional checks later on ensure
+    // that this xor is actually equivalent to the subtraction.
+    return m_CombineAnd(m_BinOp(SubBC),
+                        m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
+                                    m_Xor(m_Value(C), m_Value(B))));
+  };
+
+  const auto MatchASubBC = [&]() {
+    // Cases:
+    //   - a * (b - c)
+    //   - (b - c) * a
+    //   - (b - c) <- a implicitly equals 1
+    return m_CombineOr(m_c_Mul(m_Value(A), MatchSubBC()), MatchSubBC());
+  };
+
+  const auto MatchCD = [&]() {
+    // Cases:
+    //   - d * c
+    //   - c * d
+    //   - c <- d implicitly equals 1
+    return m_CombineOr(m_c_Mul(m_Value(D), m_Specific(C)), m_Specific(C));
+  };
+
+  const auto Match = [&](const Value *LHS, const Value *RHS) {
+    // MatchCD() builds an m_Specific(C) matcher from the current value of C,
+    // so C must already be bound when it is called: match(LHS, MatchASubBC())
+    // has to be evaluated first and succeed.
+    //
+    // If Match returns true, B and C are guaranteed to be non-null.
+    return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
+  };
+
+  if (!Match(Op0, Op1) && !Match(Op1, Op0))
+    return;
+
+  const auto ComputeKnownBitsOrOne = [&](const Value *V) {
+    // For some of the values, we leave the pointer as nullptr to
+    // signify an implicit constant 1.
+    return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
+             : KnownBits::makeConstant(APInt(BitWidth, 1));
+  };
+
+  // Check that all operands are non-negative.
+  const KnownBits KnownA = ComputeKnownBitsOrOne(A);
+  if (!KnownA.isNonNegative())
+    return;
+
+  const KnownBits KnownD = ComputeKnownBitsOrOne(D);
+  if (!KnownD.isNonNegative())
+    return;
+
+  const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
+  if (!KnownB.isNonNegative())
+    return;
+
+  const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
+  if (!KnownC.isNonNegative())
+    return;
+
+  if (SubBC->getOpcode() == Instruction::Xor) {
+    // If we matched subtraction as xor, we need to actually check that xor
+    // is semantically equivalent to subtraction.
+    //
+    // For that to be true, b has to be a mask for c.
+    // In known-bits terms, this means the following:
+    //
+    //   - b is a constant
+    if (!KnownB.isConstant())
----------------
dtcxzyw wrote:

B can be a non-constant. We can simply check `KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue())`.
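A minimal sketch of why the suggested check is sufficient. If every bit that may be set in c (`KnownC.getMaxValue()`) is known to be set in b (`KnownB.getMinValue()`), then for every concrete pair the bits of c are a subset of the bits of b. In that case b >= c and b - c == b ^ c with no borrows, whether or not b is a constant. The sketch below is not LLVM code. It exhaustively checks this claim over all 4-bit known-bits patterns, using a hypothetical `MiniKnown` struct (plain Zero/One masks standing in for `llvm::KnownBits`) and a local `isSubsetOf` that mirrors `APInt::isSubsetOf`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 4-bit stand-in for llvm::KnownBits (not LLVM code): Zero holds
// the bits known to be 0, One holds the bits known to be 1.
struct MiniKnown {
  static constexpr unsigned Mask = 0xF;
  unsigned Zero, One;

  MiniKnown(unsigned Z, unsigned O) : Zero(Z & Mask), One(O & Mask) {
    assert((Zero & One) == 0 && "a bit cannot be known 0 and known 1");
  }
  // Like KnownBits::getMaxValue(): every bit not known to be 0 is set.
  unsigned maxValue() const { return ~Zero & Mask; }
  // Like KnownBits::getMinValue(): only the bits known to be 1 are set.
  unsigned minValue() const { return One; }
  // True if the concrete value V agrees with every known bit.
  bool contains(unsigned V) const {
    return (V & Zero) == 0 && (V & One) == One;
  }
  // True if every bit is known.
  bool isConstant() const { return (Zero | One) == Mask; }
};

// Mirrors APInt::isSubsetOf: every bit set in X is also set in Y.
static bool isSubsetOf(unsigned X, unsigned Y) { return (X & ~Y) == 0; }

// Enumerates every pair of 4-bit patterns (KnownB, KnownC) that passes
// KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue()) and checks that
// b >= c and b - c == b ^ c for every concrete b and c they admit.
// Also counts the accepted pairs and those whose b is not a constant.
bool checkSubsetRule(unsigned &Accepted, unsigned &AcceptedNonConstB) {
  Accepted = AcceptedNonConstB = 0;
  for (unsigned BZ = 0; BZ <= MiniKnown::Mask; ++BZ)
    for (unsigned BO = 0; BO <= MiniKnown::Mask; ++BO) {
      if (BZ & BO)
        continue; // conflicting facts, not a valid pattern
      const MiniKnown KnownB(BZ, BO);
      for (unsigned CZ = 0; CZ <= MiniKnown::Mask; ++CZ)
        for (unsigned CO = 0; CO <= MiniKnown::Mask; ++CO) {
          if (CZ & CO)
            continue;
          const MiniKnown KnownC(CZ, CO);
          if (!isSubsetOf(KnownC.maxValue(), KnownB.minValue()))
            continue; // the proposed check rejects this pair
          ++Accepted;
          if (!KnownB.isConstant())
            ++AcceptedNonConstB;
          for (unsigned B = 0; B <= MiniKnown::Mask; ++B)
            for (unsigned C = 0; C <= MiniKnown::Mask; ++C)
              if (KnownB.contains(B) && KnownC.contains(C) &&
                  (B < C || B - C != (B ^ C)))
                return false; // counterexample found
        }
    }
  return true;
}
```

Per bit, the check accepts 5 of the 9 (KnownB, KnownC) state combinations, so `checkSubsetRule` accepts 5^4 = 625 pairs. Of these, 369 have a non-constant b, and none violates the equivalence. Those 369 pairs are exactly the cases the `KnownB.isConstant()` requirement would give up.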

https://github.com/llvm/llvm-project/pull/166378
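The upper bound U argued in the doc comment of `computeKnownBitsFromLerpPattern` can be sanity-checked by brute force. The sketch below is not part of the patch. It uses plain 32-bit unsigned integers in place of `KnownBits`/`APInt`, with operand maxima standing in for `KnownBits::getMaxValue()`. The helper names `lerpUpperBound`, `leadingZeros` and `lerpBoundHolds` are hypothetical. Operands are capped at 255 so that no intermediate result wraps, and U is required to fit in the chosen width. Fixed-width IR arithmetic would need the same no-wrap condition for the leading-zero conclusion to hold.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// U from the doc comment: max(max(a), max(d)) * max(b).
uint32_t lerpUpperBound(uint32_t MaxA, uint32_t MaxB, uint32_t MaxD) {
  return std::max(MaxA, MaxD) * MaxB;
}

// Number of leading zero bits of V viewed as a Width-bit integer.
unsigned leadingZeros(uint32_t V, unsigned Width) {
  unsigned N = 0;
  for (int Bit = int(Width) - 1; Bit >= 0 && !((V >> Bit) & 1u); --Bit)
    ++N;
  return N;
}

// Exhaustively checks that every a * (b - c) + c * d with
// 0 <= a <= MaxA, 0 <= c <= b <= MaxB and 0 <= d <= MaxD is <= U, and hence
// has at least as many leading zeros as U in a Width-bit type.
bool lerpBoundHolds(uint32_t MaxA, uint32_t MaxB, uint32_t MaxD,
                    unsigned Width) {
  // Keep operands small so no product or sum below wraps in 32 bits.
  assert(MaxA <= 255 && MaxB <= 255 && MaxD <= 255 && Width <= 32);
  const uint32_t U = lerpUpperBound(MaxA, MaxB, MaxD);
  // The conclusion only makes sense if U itself fits in Width bits.
  assert(Width == 32 || U < (uint32_t(1) << Width));
  for (uint32_t A = 0; A <= MaxA; ++A)
    for (uint32_t B = 0; B <= MaxB; ++B)
      for (uint32_t C = 0; C <= B; ++C) // b >= c, as `sub nuw` guarantees
        for (uint32_t D = 0; D <= MaxD; ++D) {
          const uint32_t R = A * (B - C) + C * D;
          if (R > U || leadingZeros(R, Width) < leadingZeros(U, Width))
            return false;
        }
  return true;
}
```

For example, with a <= 3, b <= 15 and d <= 3, U = 45 (0b00101101). An 8-bit result would therefore have at least its top two bits known to be zero. The bound is attained at a = 3, b = 15, c = 0.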
