[llvm] [InstCombine] Fold out-of-range bits for squaring signed integers (PR #153484)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 14 16:25:52 PDT 2025


https://github.com/Aethezz updated https://github.com/llvm/llvm-project/pull/153484

>From 98e536c12ae76c1e85c53b5ad5aa6b7c7bf33def Mon Sep 17 00:00:00 2001
From: Aethezz <ellisonlao999 at gmail.com>
Date: Wed, 13 Aug 2025 16:05:15 -0400
Subject: [PATCH 1/2] Add known-bits logic to fold (x * x) masks for
 out-of-range bits

---
 llvm/lib/Analysis/ValueTracking.cpp           | 43 +++++++++++++++++++
 .../test/Analysis/ValueTracking/known-bits.ll | 33 ++++++++++++++
 2 files changed, 76 insertions(+)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index af85ce4077ec8..7a973140f6075 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -423,6 +423,49 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
     Known.makeNonNegative();
   else if (isKnownNegative && !Known.isNonNegative())
     Known.makeNegative();
+
+  // Additional logic: If both operands are the same sign- or zero-extended
+  // value from a small integer, and the multiplication is (sext x) * (sext x)
+  // or (zext x) * (zext x), then the result cannot set bits above the maximum
+  // possible square. This allows InstCombine and other passes to fold (x * x) &
+  // (1 << N) to 0 when N is out of range.
+  using namespace PatternMatch;
+  const Value *A = nullptr;
+  // Only handle the case where both operands are the same extension of the same
+  // value.
+  if ((match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) ||
+      (match(Op0, m_ZExt(m_Value(A))) && match(Op1, m_ZExt(m_Specific(A))))) {
+    Type *FromTy = A->getType();
+    Type *ToTy = Op0->getType();
+    if (FromTy->isIntegerTy() && ToTy->isIntegerTy() &&
+        FromTy->getScalarSizeInBits() < ToTy->getScalarSizeInBits()) {
+      unsigned FromBits = FromTy->getScalarSizeInBits();
+      unsigned ToBits = ToTy->getScalarSizeInBits();
+      // For both signed and unsigned, the maximum absolute value is max(|min|,
+      // |max|)
+      APInt minVal(FromBits, 0), maxVal(FromBits, 0);
+      bool isSigned = isa<SExtInst>(Op0);
+      if (isSigned) {
+        minVal = APInt::getSignedMinValue(FromBits);
+        maxVal = APInt::getSignedMaxValue(FromBits);
+      } else {
+        minVal = APInt::getMinValue(FromBits);
+        maxVal = APInt::getMaxValue(FromBits);
+      }
+      APInt absMin = minVal.abs();
+      APInt absMax = maxVal.abs();
+      APInt maxAbs = absMin.ugt(absMax) ? absMin : absMax;
+      APInt maxSquare = maxAbs.zext(ToBits);
+      maxSquare = maxSquare * maxSquare;
+      // All bits above the highest set bit in maxSquare are known zero.
+      unsigned MaxBit = maxSquare.isZero() ? 0 : maxSquare.logBase2();
+      if (MaxBit + 1 < ToBits) {
+        APInt KnownZeroMask =
+            APInt::getHighBitsSet(ToBits, ToBits - (MaxBit + 1));
+        Known.Zero |= KnownZeroMask;
+      }
+    }
+  }
 }
 
 void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
diff --git a/llvm/test/Analysis/ValueTracking/known-bits.ll b/llvm/test/Analysis/ValueTracking/known-bits.ll
index 5b71402a96f0d..d9f119bd0d146 100644
--- a/llvm/test/Analysis/ValueTracking/known-bits.ll
+++ b/llvm/test/Analysis/ValueTracking/known-bits.ll
@@ -49,3 +49,36 @@ define i1 @vec_reverse_known_bits_demanded_fail(<4 x i8> %xx) {
   %r = icmp slt i8 %ele, 0
   ret i1 %r
 }
+
+; Known bits of (sext i8 x) * (sext i8 x): the product is at most (-128)^2 = 2^14.
+; RUN: opt -passes=instcombine < %s -S | FileCheck %s --check-prefix=SEXT_SQUARE
+
+define i1 @sext_square_bit31(i8 %x) {
+; SEXT_SQUARE-LABEL: @sext_square_bit31(
+; SEXT_SQUARE-NEXT:    ret i1 false
+  %wide = sext i8 %x to i32                ; %wide is in [-128, 127]
+  %sq = mul nsw i32 %wide, %wide           ; so %sq <= (-128)^2 = 2^14
+  %masked = and i32 %sq, 2147483648        ; bit 31 (0x80000000): above bit 14, known zero
+  %isset = icmp ne i32 %masked, 0
+  ret i1 %isset
+}
+
+define i1 @sext_square_bit30(i8 %x) {
+; SEXT_SQUARE-LABEL: @sext_square_bit30(
+; SEXT_SQUARE-NEXT:    ret i1 false
+  %wide = sext i8 %x to i32                ; %wide is in [-128, 127]
+  %sq = mul nsw i32 %wide, %wide           ; so %sq <= (-128)^2 = 2^14
+  %masked = and i32 %sq, 1073741824        ; bit 30 (0x40000000): above bit 14, known zero
+  %isset = icmp ne i32 %masked, 0
+  ret i1 %isset
+}
+
+define i1 @sext_square_bit14(i8 %x) {
+; SEXT_SQUARE-LABEL: @sext_square_bit14(
+; SEXT_SQUARE-NOT: ret i1 false
+  %wide = sext i8 %x to i32                ; %wide is in [-128, 127]
+  %sq = mul nsw i32 %wide, %wide           ; %x == -128 gives exactly 2^14
+  %masked = and i32 %sq, 16384             ; bit 14 (0x4000): reachable, must not fold
+  %isset = icmp ne i32 %masked, 0          ; NOTE(review): no test covers zext or i8 -> iN with N < 16,
+  ret i1 %isset                            ; where maxSquare (computed in N bits) can wrap
+}

>From f5f6d15611cb9b286aff53284e020c8e4d901361 Mon Sep 17 00:00:00 2001
From: Aethezz <64500703+Aethezz at users.noreply.github.com>
Date: Thu, 14 Aug 2025 19:25:43 -0400
Subject: [PATCH 2/2] Remove redundant using-directive for PatternMatch

---
 llvm/lib/Analysis/ValueTracking.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9fa26edc249ca..feb350a4a4692 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -429,7 +429,6 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
   // or (zext x) * (zext x), then the result cannot set bits above the maximum
   // possible square. This allows InstCombine and other passes to fold (x * x) &
   // (1 << N) to 0 when N is out of range.
-  using namespace PatternMatch;
   const Value *A = nullptr;
   // Only handle the case where both operands are the same extension of the same
   // value.



More information about the llvm-commits mailing list