[llvm] [InstCombine][missed-optimizations] Fold out-of-range bits for squaring signed integers (PR #153484)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 13 13:15:04 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: None (Aethezz)
<details>
<summary>Changes</summary>
Fixes an issue where the high bits just below the sign bit were not constant-folded when squaring a sign- or zero-extended small integer. This adds logic to computeKnownBitsMul in ValueTracking that detects when both operands of a multiplication are the same extension of the same value, and marks every bit above the highest bit of the maximum possible square as known zero. InstCombine can then fold (ext(x) * ext(x)) & (1 << N) to 0 when N is out of range. For example, for sext i8 to i32 the square is at most 128 * 128 = 2^14, so this applies to any N >= 15.
Fixes #<!-- -->152061
---
Full diff: https://github.com/llvm/llvm-project/pull/153484.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+43)
- (modified) llvm/test/Analysis/ValueTracking/known-bits.ll (+33)
``````````diff
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index af85ce4077ec8..7a973140f6075 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -423,6 +423,49 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
Known.makeNonNegative();
else if (isKnownNegative && !Known.isNonNegative())
Known.makeNegative();
+
+ // Additional logic: If both operands are the same sign- or zero-extended
+ // value from a small integer, and the multiplication is (sext x) * (sext x)
+ // or (zext x) * (zext x), then the result cannot set bits above the maximum
+ // possible square. This allows InstCombine and other passes to fold (x * x) &
+ // (1 << N) to 0 when N is out of range.
+ using namespace PatternMatch;
+ const Value *A = nullptr;
+ // Only handle the case where both operands are the same extension of the same
+ // value.
+ if ((match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) ||
+ (match(Op0, m_ZExt(m_Value(A))) && match(Op1, m_ZExt(m_Specific(A))))) {
+ Type *FromTy = A->getType();
+ Type *ToTy = Op0->getType();
+ if (FromTy->isIntegerTy() && ToTy->isIntegerTy() &&
+ FromTy->getScalarSizeInBits() < ToTy->getScalarSizeInBits()) {
+ unsigned FromBits = FromTy->getScalarSizeInBits();
+ unsigned ToBits = ToTy->getScalarSizeInBits();
+ // For both signed and unsigned, the maximum absolute value is max(|min|,
+ // |max|)
+ APInt minVal(FromBits, 0), maxVal(FromBits, 0);
+ bool isSigned = isa<SExtInst>(Op0);
+ if (isSigned) {
+ minVal = APInt::getSignedMinValue(FromBits);
+ maxVal = APInt::getSignedMaxValue(FromBits);
+ } else {
+ minVal = APInt::getMinValue(FromBits);
+ maxVal = APInt::getMaxValue(FromBits);
+ }
+ APInt absMin = minVal.abs();
+ APInt absMax = maxVal.abs();
+ APInt maxAbs = absMin.ugt(absMax) ? absMin : absMax;
+ APInt maxSquare = maxAbs.zext(ToBits);
+ maxSquare = maxSquare * maxSquare;
+ // All bits above the highest set bit in maxSquare are known zero.
+ unsigned MaxBit = maxSquare.isZero() ? 0 : maxSquare.logBase2();
+ if (MaxBit + 1 < ToBits) {
+ APInt KnownZeroMask =
+ APInt::getHighBitsSet(ToBits, ToBits - (MaxBit + 1));
+ Known.Zero |= KnownZeroMask;
+ }
+ }
+ }
}
void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
diff --git a/llvm/test/Analysis/ValueTracking/known-bits.ll b/llvm/test/Analysis/ValueTracking/known-bits.ll
index 5b71402a96f0d..d9f119bd0d146 100644
--- a/llvm/test/Analysis/ValueTracking/known-bits.ll
+++ b/llvm/test/Analysis/ValueTracking/known-bits.ll
@@ -49,3 +49,36 @@ define i1 @vec_reverse_known_bits_demanded_fail(<4 x i8> %xx) {
%r = icmp slt i8 %ele, 0
ret i1 %r
}
+
+; Known bits for (sext i8 x) * (sext i8 x): |x| <= 128, so x*x <= 2^14 and bits 15..31 are zero.
+; RUN: opt -passes=instcombine < %s -S | FileCheck %s --check-prefix=SEXT_SQUARE
+
+define i1 @sext_square_bit31(i8 noundef %x) { ; noundef: both mul operands must observe the same x
+; SEXT_SQUARE-LABEL: @sext_square_bit31(
+; SEXT_SQUARE-NEXT: ret i1 false
+ %sx = sext i8 %x to i32 ; sx is in [-128, 127]
+ %mul = mul nsw i32 %sx, %sx ; 0 <= mul <= 128*128 = 2^14
+ %and = and i32 %mul, 2147483648 ; 1 << 31 (sign bit): never set in a square of sx
+ %cmp = icmp ne i32 %and, 0
+ ret i1 %cmp
+}
+
+define i1 @sext_square_bit30(i8 noundef %x) { ; noundef: both mul operands must observe the same x
+; SEXT_SQUARE-LABEL: @sext_square_bit30(
+; SEXT_SQUARE-NEXT: ret i1 false
+ %sx = sext i8 %x to i32 ; sx is in [-128, 127]
+ %mul = mul nsw i32 %sx, %sx ; 0 <= mul <= 2^14, so bits 15..31 are zero
+ %and = and i32 %mul, 1073741824 ; 1 << 30: above bit 14, the highest bit x*x can set
+ %cmp = icmp ne i32 %and, 0
+ ret i1 %cmp
+}
+
+define i1 @sext_square_bit14(i8 noundef %x) { ; noundef for consistency with the positive tests
+; SEXT_SQUARE-LABEL: @sext_square_bit14(
+; SEXT_SQUARE-NOT: ret i1 false
+ %sx = sext i8 %x to i32 ; sx is in [-128, 127]
+ %mul = mul nsw i32 %sx, %sx ; max square is exactly 2^14 (x == -128)
+ %and = and i32 %mul, 16384 ; 1 << 14: set when x == -128, so this must NOT fold to false
+ %cmp = icmp ne i32 %and, 0
+ ret i1 %cmp
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/153484
More information about the llvm-commits
mailing list