[llvm] [InstCombine] Fold out-of-range bits for squaring signed integers (PR #153484)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 19 05:25:10 PDT 2025


================
@@ -423,6 +423,48 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
     Known.makeNonNegative();
   else if (isKnownNegative && !Known.isNonNegative())
     Known.makeNegative();
+
+  // Additional logic: If both operands are the same sign- or zero-extended
+  // value from a small integer, and the multiplication is (sext x) * (sext x)
+  // or (zext x) * (zext x), then the result cannot set bits above the maximum
+  // possible square. This allows InstCombine and other passes to fold (x * x) &
+  // (1 << N) to 0 when N is out of range.
+  const Value *A = nullptr;
+  // Only handle the case where both operands are the same extension of the same
+  // value.
+  if ((match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) ||
+      (match(Op0, m_ZExt(m_Value(A))) && match(Op1, m_ZExt(m_Specific(A))))) {
----------------
nikic wrote:

You can use the same logic as ComputeNumSignBits: https://github.com/llvm/llvm-project/blob/92a91f71ee217b71a0655338dc063d557fbe33c0/llvm/lib/Analysis/ValueTracking.cpp#L4278-L4280

Adjusted for the case where both operands are the same value, and therefore have the same number of sign bits:
```
      unsigned OutValidBits = 2 * (TyBits - SignBits + 1);
      unsigned OutSignBits = OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
```
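As a hedged illustration, the sketch below wraps the suggested formula in standalone C++ functions. The names `validBits`, `squareSignBits` and `squareBitKnownZero` are hypothetical helpers, not LLVM APIs. `TyBits` and `SignBits` mean the same as in the snippet above. The last helper ties the sign-bit bound back to the fold the patch wants, `(x * x) & (1 << N)` becoming 0. It assumes that a square which provably fits in the type is non-negative, so all of its sign bits are zero.

```cpp
#include <cassert>

// Bits needed to represent a value that has SignBits copies of its sign bit
// in a TyBits-wide integer. The sign bit itself is counted once.
unsigned validBits(unsigned TyBits, unsigned SignBits) {
  assert(SignBits >= 1 && SignBits <= TyBits && "sign-bit count out of range");
  return TyBits - SignBits + 1;
}

// Hypothetical helper (not an LLVM API): conservative lower bound on the
// number of sign bits of X * X, given that X has SignBits sign bits.
unsigned squareSignBits(unsigned TyBits, unsigned SignBits) {
  // A product needs at most the sum of its operands' valid bits. Both
  // operands are X here, so that sum is twice the valid bits of X.
  unsigned OutValidBits = 2 * validBits(TyBits, SignBits);
  return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
}

// Hypothetical helper: is bit N of X * X known to be zero? When the bound
// gives more than one sign bit, the product cannot wrap. A square is then
// non-negative, so bits TyBits - OutSignBits through TyBits - 1 are all zero.
bool squareBitKnownZero(unsigned TyBits, unsigned SignBits, unsigned N) {
  unsigned OutSignBits = squareSignBits(TyBits, SignBits);
  return OutSignBits > 1 && N >= TyBits - OutSignBits;
}
```

Consider an i8 value sign-extended to i32, which has 25 sign bits. The bound is 17 sign bits, so bits 15 through 31 of the square are known zero and `(x * x) & (1 << 15)` could fold to 0. This case is tight: (-128)^2 = 16384 sets bit 14. For an i8 value zero-extended to i32, which has 24 sign bits, the bound is 15. The true minimum is 16, because 255^2 < 2^16. The formula therefore stays conservative there.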

https://github.com/llvm/llvm-project/pull/153484

