[llvm] [InstCombine] Fold out-of-range bits for squaring signed integers (PR #153484)

via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 4 09:15:23 PDT 2025


https://github.com/Aethezz updated https://github.com/llvm/llvm-project/pull/153484

>From 98e536c12ae76c1e85c53b5ad5aa6b7c7bf33def Mon Sep 17 00:00:00 2001
From: Aethezz <ellisonlao999 at gmail.com>
Date: Wed, 13 Aug 2025 16:05:15 -0400
Subject: [PATCH 1/9] Added logic to fold (x * x) masks for
 out-of-range bits

---
 llvm/lib/Analysis/ValueTracking.cpp           | 43 +++++++++++++++++++
 .../test/Analysis/ValueTracking/known-bits.ll | 33 ++++++++++++++
 2 files changed, 76 insertions(+)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index af85ce4077ec8..7a973140f6075 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -423,6 +423,49 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
     Known.makeNonNegative();
   else if (isKnownNegative && !Known.isNonNegative())
     Known.makeNegative();
+
+  // Additional logic: If both operands are the same sign- or zero-extended
+  // value from a small integer, and the multiplication is (sext x) * (sext x)
+  // or (zext x) * (zext x), then the result cannot set bits above the maximum
+  // possible square. This allows InstCombine and other passes to fold (x * x) &
+  // (1 << N) to 0 when N is out of range.
+  using namespace PatternMatch;
+  const Value *A = nullptr;
+  // Only handle the case where both operands are the same extension of the same
+  // value.
+  if ((match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) ||
+      (match(Op0, m_ZExt(m_Value(A))) && match(Op1, m_ZExt(m_Specific(A))))) {
+    Type *FromTy = A->getType();
+    Type *ToTy = Op0->getType();
+    if (FromTy->isIntegerTy() && ToTy->isIntegerTy() &&
+        FromTy->getScalarSizeInBits() < ToTy->getScalarSizeInBits()) {
+      unsigned FromBits = FromTy->getScalarSizeInBits();
+      unsigned ToBits = ToTy->getScalarSizeInBits();
+      // For both signed and unsigned, the maximum absolute value is max(|min|,
+      // |max|)
+      APInt minVal(FromBits, 0), maxVal(FromBits, 0);
+      bool isSigned = isa<SExtInst>(Op0);
+      if (isSigned) {
+        minVal = APInt::getSignedMinValue(FromBits);
+        maxVal = APInt::getSignedMaxValue(FromBits);
+      } else {
+        minVal = APInt::getMinValue(FromBits);
+        maxVal = APInt::getMaxValue(FromBits);
+      }
+      APInt absMin = minVal.abs();
+      APInt absMax = maxVal.abs();
+      APInt maxAbs = absMin.ugt(absMax) ? absMin : absMax;
+      APInt maxSquare = maxAbs.zext(ToBits);
+      maxSquare = maxSquare * maxSquare;
+      // All bits above the highest set bit in maxSquare are known zero.
+      unsigned MaxBit = maxSquare.isZero() ? 0 : maxSquare.logBase2();
+      if (MaxBit + 1 < ToBits) {
+        APInt KnownZeroMask =
+            APInt::getHighBitsSet(ToBits, ToBits - (MaxBit + 1));
+        Known.Zero |= KnownZeroMask;
+      }
+    }
+  }
 }
 
 void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
diff --git a/llvm/test/Analysis/ValueTracking/known-bits.ll b/llvm/test/Analysis/ValueTracking/known-bits.ll
index 5b71402a96f0d..d9f119bd0d146 100644
--- a/llvm/test/Analysis/ValueTracking/known-bits.ll
+++ b/llvm/test/Analysis/ValueTracking/known-bits.ll
@@ -49,3 +49,36 @@ define i1 @vec_reverse_known_bits_demanded_fail(<4 x i8> %xx) {
   %r = icmp slt i8 %ele, 0
   ret i1 %r
 }
+
+; Test known bits for (sext i8 x) * (sext i8 x)
+; RUN: opt -passes=instcombine < %s -S | FileCheck %s --check-prefix=SEXT_SQUARE
+
+define i1 @sext_square_bit31(i8 %x) {
+; SEXT_SQUARE-LABEL: @sext_square_bit31(
+; SEXT_SQUARE-NEXT:    ret i1 false
+  %sx = sext i8 %x to i32
+  %mul = mul nsw i32 %sx, %sx
+  %and = and i32 %mul, 2147483648 ; 1 << 31
+  %cmp = icmp ne i32 %and, 0
+  ret i1 %cmp
+}
+
+define i1 @sext_square_bit30(i8 %x) {
+; SEXT_SQUARE-LABEL: @sext_square_bit30(
+; SEXT_SQUARE-NEXT:    ret i1 false
+  %sx = sext i8 %x to i32
+  %mul = mul nsw i32 %sx, %sx
+  %and = and i32 %mul, 1073741824 ; 1 << 30
+  %cmp = icmp ne i32 %and, 0
+  ret i1 %cmp
+}
+
+define i1 @sext_square_bit14(i8 %x) {
+; SEXT_SQUARE-LABEL: @sext_square_bit14(
+; SEXT_SQUARE-NOT: ret i1 false
+  %sx = sext i8 %x to i32
+  %mul = mul nsw i32 %sx, %sx
+  %and = and i32 %mul, 16384 ; 1 << 14
+  %cmp = icmp ne i32 %and, 0
+  ret i1 %cmp
+}

>From f5f6d15611cb9b286aff53284e020c8e4d901361 Mon Sep 17 00:00:00 2001
From: Aethezz <64500703+Aethezz at users.noreply.github.com>
Date: Thu, 14 Aug 2025 19:25:43 -0400
Subject: [PATCH 2/9] remove redundant namespace

---
 llvm/lib/Analysis/ValueTracking.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9fa26edc249ca..feb350a4a4692 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -429,7 +429,6 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
   // or (zext x) * (zext x), then the result cannot set bits above the maximum
   // possible square. This allows InstCombine and other passes to fold (x * x) &
   // (1 << N) to 0 when N is out of range.
-  using namespace PatternMatch;
   const Value *A = nullptr;
   // Only handle the case where both operands are the same extension of the same
   // value.

>From 246d414e9d6e8e0fe498b4bd8c32e7473f6ea6ac Mon Sep 17 00:00:00 2001
From: Aethezz <ellisonlao999 at gmail.com>
Date: Mon, 18 Aug 2025 17:26:42 -0400
Subject: [PATCH 3/9] use ComputeNumSignBits() instead and remove zext handling

---
 llvm/lib/Analysis/ValueTracking.cpp | 41 +++++------------------------
 1 file changed, 7 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index feb350a4a4692..f332272bfe8f5 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -430,40 +430,13 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
   // possible square. This allows InstCombine and other passes to fold (x * x) &
   // (1 << N) to 0 when N is out of range.
   const Value *A = nullptr;
-  // Only handle the case where both operands are the same extension of the same
-  // value.
-  if ((match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) ||
-      (match(Op0, m_ZExt(m_Value(A))) && match(Op1, m_ZExt(m_Specific(A))))) {
-    Type *FromTy = A->getType();
-    Type *ToTy = Op0->getType();
-    if (FromTy->isIntegerTy() && ToTy->isIntegerTy() &&
-        FromTy->getScalarSizeInBits() < ToTy->getScalarSizeInBits()) {
-      unsigned FromBits = FromTy->getScalarSizeInBits();
-      unsigned ToBits = ToTy->getScalarSizeInBits();
-      // For both signed and unsigned, the maximum absolute value is max(|min|,
-      // |max|)
-      APInt minVal(FromBits, 0), maxVal(FromBits, 0);
-      bool isSigned = isa<SExtInst>(Op0);
-      if (isSigned) {
-        minVal = APInt::getSignedMinValue(FromBits);
-        maxVal = APInt::getSignedMaxValue(FromBits);
-      } else {
-        minVal = APInt::getMinValue(FromBits);
-        maxVal = APInt::getMaxValue(FromBits);
-      }
-      APInt absMin = minVal.abs();
-      APInt absMax = maxVal.abs();
-      APInt maxAbs = absMin.ugt(absMax) ? absMin : absMax;
-      APInt maxSquare = maxAbs.zext(ToBits);
-      maxSquare = maxSquare * maxSquare;
-      // All bits above the highest set bit in maxSquare are known zero.
-      unsigned MaxBit = maxSquare.isZero() ? 0 : maxSquare.logBase2();
-      if (MaxBit + 1 < ToBits) {
-        APInt KnownZeroMask =
-            APInt::getHighBitsSet(ToBits, ToBits - (MaxBit + 1));
-        Known.Zero |= KnownZeroMask;
-      }
-    }
+
+  if (match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) {
+    // Product of (sext x) * (sext x) is always non-negative.
+    // So we know the sign bit itself is zero.
+    unsigned SignBits = ComputeNumSignBits(Op0, Q, Depth);
+    if (SignBits > 1)
+      Known.Zero.setHighBits(SignBits - 1);
   }
 }
 

>From 0c58e22c82a1c6a550905aaa06741607a0b61539 Mon Sep 17 00:00:00 2001
From: Aethezz <ellisonlao999 at gmail.com>
Date: Mon, 18 Aug 2025 17:32:56 -0400
Subject: [PATCH 4/9] remove previous comment

---
 llvm/lib/Analysis/ValueTracking.cpp | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f332272bfe8f5..0607ef3c9dec4 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -424,11 +424,7 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
   else if (isKnownNegative && !Known.isNonNegative())
     Known.makeNegative();
 
-  // Additional logic: If both operands are the same sign- or zero-extended
-  // value from a small integer, and the multiplication is (sext x) * (sext x)
-  // or (zext x) * (zext x), then the result cannot set bits above the maximum
-  // possible square. This allows InstCombine and other passes to fold (x * x) &
-  // (1 << N) to 0 when N is out of range.
+  // Check if both operands are the same sign-extension of a single value.
   const Value *A = nullptr;
 
   if (match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) {

>From 0ff6997e4fa39ed13b27fbeef279c80a33ed0545 Mon Sep 17 00:00:00 2001
From: Aethezz <ellisonlao999 at gmail.com>
Date: Mon, 18 Aug 2025 21:44:30 -0400
Subject: [PATCH 5/9] temporarily revert to previous changes and remove zext
 handling

---
 llvm/lib/Analysis/ValueTracking.cpp | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 0607ef3c9dec4..5581a89cb1d35 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -426,13 +426,31 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
 
   // Check if both operands are the same sign-extension of a single value.
   const Value *A = nullptr;
-
   if (match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) {
     // Product of (sext x) * (sext x) is always non-negative.
-    // So we know the sign bit itself is zero.
-    unsigned SignBits = ComputeNumSignBits(Op0, Q, Depth);
-    if (SignBits > 1)
-      Known.Zero.setHighBits(SignBits - 1);
+    // Compute the maximum possible square and fold all out-of-range bits.
+    Type *FromTy = A->getType();
+    Type *ToTy = Op0->getType();
+    if (FromTy->isIntegerTy() && ToTy->isIntegerTy() &&
+        FromTy->getScalarSizeInBits() < ToTy->getScalarSizeInBits()) {
+      unsigned FromBits = FromTy->getScalarSizeInBits();
+      unsigned ToBits = ToTy->getScalarSizeInBits();
+      // For signed, the maximum absolute value is max(|min|, |max|)
+      APInt minVal = APInt::getSignedMinValue(FromBits);
+      APInt maxVal = APInt::getSignedMaxValue(FromBits);
+      APInt absMin = minVal.abs();
+      APInt absMax = maxVal.abs();
+      APInt maxAbs = absMin.ugt(absMax) ? absMin : absMax;
+      APInt maxSquare = maxAbs.zext(ToBits);
+      maxSquare = maxSquare * maxSquare;
+      // All bits above the highest set bit in maxSquare are known zero.
+      unsigned MaxBit = maxSquare.isZero() ? 0 : maxSquare.logBase2();
+      if (MaxBit + 1 < ToBits) {
+        APInt KnownZeroMask =
+            APInt::getHighBitsSet(ToBits, ToBits - (MaxBit + 1));
+        Known.Zero |= KnownZeroMask;
+      }
+    }
   }
 }
 

>From 72cd12520a77275cd57135cbe576c8ebf9eb85f9 Mon Sep 17 00:00:00 2001
From: Aethezz <ellisonlao999 at gmail.com>
Date: Tue, 19 Aug 2025 13:42:23 -0400
Subject: [PATCH 6/9] Added logic to compute the max number of valid and sign
 bits and set the out-of-range high bits to known zero

---
 llvm/lib/Analysis/ValueTracking.cpp | 33 +++++++++--------------------
 1 file changed, 10 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 5581a89cb1d35..993965a33bd79 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -427,29 +427,16 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
   // Check if both operands are the same sign-extension of a single value.
   const Value *A = nullptr;
   if (match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) {
-    // Product of (sext x) * (sext x) is always non-negative.
-    // Compute the maximum possible square and fold all out-of-range bits.
-    Type *FromTy = A->getType();
-    Type *ToTy = Op0->getType();
-    if (FromTy->isIntegerTy() && ToTy->isIntegerTy() &&
-        FromTy->getScalarSizeInBits() < ToTy->getScalarSizeInBits()) {
-      unsigned FromBits = FromTy->getScalarSizeInBits();
-      unsigned ToBits = ToTy->getScalarSizeInBits();
-      // For signed, the maximum absolute value is max(|min|, |max|)
-      APInt minVal = APInt::getSignedMinValue(FromBits);
-      APInt maxVal = APInt::getSignedMaxValue(FromBits);
-      APInt absMin = minVal.abs();
-      APInt absMax = maxVal.abs();
-      APInt maxAbs = absMin.ugt(absMax) ? absMin : absMax;
-      APInt maxSquare = maxAbs.zext(ToBits);
-      maxSquare = maxSquare * maxSquare;
-      // All bits above the highest set bit in maxSquare are known zero.
-      unsigned MaxBit = maxSquare.isZero() ? 0 : maxSquare.logBase2();
-      if (MaxBit + 1 < ToBits) {
-        APInt KnownZeroMask =
-            APInt::getHighBitsSet(ToBits, ToBits - (MaxBit + 1));
-        Known.Zero |= KnownZeroMask;
-      }
+    unsigned SignBits = ComputeNumSignBits(Op0, DemandedElts, Q, Depth + 1);
+    unsigned TyBits = Op0->getType()->getScalarSizeInBits();
+    // The output of the Mul can be at most twice the valid bits
+    unsigned OutValidBits = 2 * (TyBits - SignBits + 1);
+    unsigned OutSignBits =
+        OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
+
+    if (OutSignBits > 1) {
+      APInt KnownZeroMask = APInt::getHighBitsSet(TyBits, OutSignBits);
+      Known.Zero |= KnownZeroMask;
     }
   }
 }

>From 514f2670e3bb92e5f119d845b298d7b92e597e46 Mon Sep 17 00:00:00 2001
From: Aethezz <ellisonlao999 at gmail.com>
Date: Tue, 19 Aug 2025 17:34:42 -0400
Subject: [PATCH 7/9] removed match and moved into selfmultiply condition,
 simplified condition to drop ternary and moved test to
 llvm/test/Transforms/InstCombine since it wasn't folded by instsimplify

---
 llvm/lib/Analysis/ValueTracking.cpp           | 34 +++++++++----------
 .../test/Analysis/ValueTracking/known-bits.ll | 33 ------------------
 llvm/test/Transforms/InstCombine/sext.ll      | 32 +++++++++++++++++
 3 files changed, 48 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 993965a33bd79..f58c82d1ce9a6 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -409,10 +409,24 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
   }
 
   bool SelfMultiply = Op0 == Op1;
-  if (SelfMultiply)
+  if (SelfMultiply) {
     SelfMultiply &=
         isGuaranteedNotToBeUndef(Op0, Q.AC, Q.CxtI, Q.DT, Depth + 1);
-  Known = KnownBits::mul(Known, Known2, SelfMultiply);
+
+    Known = KnownBits::mul(Known, Known2, SelfMultiply);
+
+    unsigned SignBits = ComputeNumSignBits(Op0, DemandedElts, Q, Depth + 1);
+    unsigned TyBits = Op0->getType()->getScalarSizeInBits();
+    unsigned OutValidBits = 2 * (TyBits - SignBits + 1);
+
+    if (OutValidBits < TyBits) {
+      APInt KnownZeroMask =
+          APInt::getHighBitsSet(TyBits, TyBits - OutValidBits + 1);
+      Known.Zero |= KnownZeroMask;
+    }
+  } else {
+    Known = KnownBits::mul(Known, Known2, SelfMultiply);
+  }
 
   // Only make use of no-wrap flags if we failed to compute the sign bit
   // directly.  This matters if the multiplication always overflows, in
@@ -423,22 +437,6 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
     Known.makeNonNegative();
   else if (isKnownNegative && !Known.isNonNegative())
     Known.makeNegative();
-
-  // Check if both operands are the same sign-extension of a single value.
-  const Value *A = nullptr;
-  if (match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) {
-    unsigned SignBits = ComputeNumSignBits(Op0, DemandedElts, Q, Depth + 1);
-    unsigned TyBits = Op0->getType()->getScalarSizeInBits();
-    // The output of the Mul can be at most twice the valid bits
-    unsigned OutValidBits = 2 * (TyBits - SignBits + 1);
-    unsigned OutSignBits =
-        OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
-
-    if (OutSignBits > 1) {
-      APInt KnownZeroMask = APInt::getHighBitsSet(TyBits, OutSignBits);
-      Known.Zero |= KnownZeroMask;
-    }
-  }
 }
 
 void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
diff --git a/llvm/test/Analysis/ValueTracking/known-bits.ll b/llvm/test/Analysis/ValueTracking/known-bits.ll
index d9f119bd0d146..5b71402a96f0d 100644
--- a/llvm/test/Analysis/ValueTracking/known-bits.ll
+++ b/llvm/test/Analysis/ValueTracking/known-bits.ll
@@ -49,36 +49,3 @@ define i1 @vec_reverse_known_bits_demanded_fail(<4 x i8> %xx) {
   %r = icmp slt i8 %ele, 0
   ret i1 %r
 }
-
-; Test known bits for (sext i8 x) * (sext i8 x)
-; RUN: opt -passes=instcombine < %s -S | FileCheck %s --check-prefix=SEXT_SQUARE
-
-define i1 @sext_square_bit31(i8 %x) {
-; SEXT_SQUARE-LABEL: @sext_square_bit31(
-; SEXT_SQUARE-NEXT:    ret i1 false
-  %sx = sext i8 %x to i32
-  %mul = mul nsw i32 %sx, %sx
-  %and = and i32 %mul, 2147483648 ; 1 << 31
-  %cmp = icmp ne i32 %and, 0
-  ret i1 %cmp
-}
-
-define i1 @sext_square_bit30(i8 %x) {
-; SEXT_SQUARE-LABEL: @sext_square_bit30(
-; SEXT_SQUARE-NEXT:    ret i1 false
-  %sx = sext i8 %x to i32
-  %mul = mul nsw i32 %sx, %sx
-  %and = and i32 %mul, 1073741824 ; 1 << 30
-  %cmp = icmp ne i32 %and, 0
-  ret i1 %cmp
-}
-
-define i1 @sext_square_bit14(i8 %x) {
-; SEXT_SQUARE-LABEL: @sext_square_bit14(
-; SEXT_SQUARE-NOT: ret i1 false
-  %sx = sext i8 %x to i32
-  %mul = mul nsw i32 %sx, %sx
-  %and = and i32 %mul, 16384 ; 1 << 14
-  %cmp = icmp ne i32 %and, 0
-  ret i1 %cmp
-}
diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index ee3c52259f930..276543790406d 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -423,3 +423,35 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) {
   %s = sext i8 %a to i64
   ret i64 %s
 }
+
+; Test known bits for (sext i8 x) * (sext i8 x)
+
+define i1 @sext_square_bit30(i8 %x) {
+; CHECK-LABEL: @sext_square_bit30(
+; CHECK:  ret i1 false
+  %sx = sext i8 %x to i32
+  %mul = mul nsw i32 %sx, %sx
+  %and = and i32 %mul, 1073741824 ; 1 << 30
+  %cmp = icmp ne i32 %and, 0
+  ret i1 %cmp
+}
+
+define i1 @sext_square_bit15(i8 %x) {
+; CHECK-LABEL: @sext_square_bit15(
+; CHECK:  ret i1 false
+  %sx = sext i8 %x to i32
+  %mul = mul nsw i32 %sx, %sx
+  %and = and i32 %mul, 32768 ; 1 << 15
+  %cmp = icmp ne i32 %and, 0
+  ret i1 %cmp
+}
+
+define i1 @sext_square_bit14(i8 %x) {
+; CHECK-LABEL: @sext_square_bit14(
+; CHECK-NOT: ret i1 false
+  %sx = sext i8 %x to i32
+  %mul = mul nsw i32 %sx, %sx
+  %and = and i32 %mul, 16384 ; 1 << 14
+  %cmp = icmp ne i32 %and, 0
+  ret i1 %cmp
+}

>From 54acb353d8515ab5fccc1ee3e9d19d95f70d6e72 Mon Sep 17 00:00:00 2001
From: Aethezz <ellisonlao999 at gmail.com>
Date: Wed, 20 Aug 2025 12:49:09 -0400
Subject: [PATCH 8/9] fixed formatting on testcase

---
 llvm/test/Transforms/InstCombine/sext.ll | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index 276543790406d..aef2b1074d297 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -428,7 +428,8 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) {
 
 define i1 @sext_square_bit30(i8 %x) {
 ; CHECK-LABEL: @sext_square_bit30(
-; CHECK:  ret i1 false
+; CHECK-NEXT:    ret i1 false
+;
   %sx = sext i8 %x to i32
   %mul = mul nsw i32 %sx, %sx
   %and = and i32 %mul, 1073741824 ; 1 << 30
@@ -438,7 +439,8 @@ define i1 @sext_square_bit30(i8 %x) {
 
 define i1 @sext_square_bit15(i8 %x) {
 ; CHECK-LABEL: @sext_square_bit15(
-; CHECK:  ret i1 false
+; CHECK-NEXT:    ret i1 false
+;
   %sx = sext i8 %x to i32
   %mul = mul nsw i32 %sx, %sx
   %and = and i32 %mul, 32768 ; 1 << 15
@@ -448,7 +450,11 @@ define i1 @sext_square_bit15(i8 %x) {
 
 define i1 @sext_square_bit14(i8 %x) {
 ; CHECK-LABEL: @sext_square_bit14(
-; CHECK-NOT: ret i1 false
+; CHECK-NEXT:    [[SX:%.*]] = sext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[SX]], [[SX]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp samesign ugt i32 [[MUL]], 16383
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
   %sx = sext i8 %x to i32
   %mul = mul nsw i32 %sx, %sx
   %and = and i32 %mul, 16384 ; 1 << 14

>From 579b5103fbfac9c5a929b9c9da0da5a0d4fdb2fc Mon Sep 17 00:00:00 2001
From: Aethezz <ellisonlao999 at gmail.com>
Date: Thu, 4 Sep 2025 12:15:05 -0400
Subject: [PATCH 9/9] split into two if statements and added noundef to test cases

---
 llvm/lib/Analysis/ValueTracking.cpp      | 8 +++-----
 llvm/test/Transforms/InstCombine/sext.ll | 6 +++---
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f58c82d1ce9a6..8996ef2488a25 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -409,12 +409,12 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
   }
 
   bool SelfMultiply = Op0 == Op1;
-  if (SelfMultiply) {
+  if (SelfMultiply)
     SelfMultiply &=
         isGuaranteedNotToBeUndef(Op0, Q.AC, Q.CxtI, Q.DT, Depth + 1);
+  Known = KnownBits::mul(Known, Known2, SelfMultiply);
 
-    Known = KnownBits::mul(Known, Known2, SelfMultiply);
-
+  if (SelfMultiply) {
     unsigned SignBits = ComputeNumSignBits(Op0, DemandedElts, Q, Depth + 1);
     unsigned TyBits = Op0->getType()->getScalarSizeInBits();
     unsigned OutValidBits = 2 * (TyBits - SignBits + 1);
@@ -424,8 +424,6 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
           APInt::getHighBitsSet(TyBits, TyBits - OutValidBits + 1);
       Known.Zero |= KnownZeroMask;
     }
-  } else {
-    Known = KnownBits::mul(Known, Known2, SelfMultiply);
   }
 
   // Only make use of no-wrap flags if we failed to compute the sign bit
diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index aef2b1074d297..c72614d526036 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -426,7 +426,7 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) {
 
 ; Test known bits for (sext i8 x) * (sext i8 x)
 
-define i1 @sext_square_bit30(i8 %x) {
+define i1 @sext_square_bit30(i8 noundef %x) {
 ; CHECK-LABEL: @sext_square_bit30(
 ; CHECK-NEXT:    ret i1 false
 ;
@@ -437,7 +437,7 @@ define i1 @sext_square_bit30(i8 %x) {
   ret i1 %cmp
 }
 
-define i1 @sext_square_bit15(i8 %x) {
+define i1 @sext_square_bit15(i8 noundef %x) {
 ; CHECK-LABEL: @sext_square_bit15(
 ; CHECK-NEXT:    ret i1 false
 ;
@@ -448,7 +448,7 @@ define i1 @sext_square_bit15(i8 %x) {
   ret i1 %cmp
 }
 
-define i1 @sext_square_bit14(i8 %x) {
+define i1 @sext_square_bit14(i8 noundef %x) {
 ; CHECK-LABEL: @sext_square_bit14(
 ; CHECK-NEXT:    [[SX:%.*]] = sext i8 [[X:%.*]] to i32
 ; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[SX]], [[SX]]



More information about the llvm-commits mailing list