[llvm] [InstCombineCompares] Try to "strengthen" compares based on known bits. (PR #79405)

Mikhail Gudim via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 24 20:32:38 PST 2024


https://github.com/mgudim created https://github.com/llvm/llvm-project/pull/79405

For example, replace `icmp ugt %x, 14` with `icmp ugt %x, 15` when it is known that the two least significant bits of `%x` are zero.

>From df2bae8822ebf9db11e9dcae16c9bb7c760476f1 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at gmail.com>
Date: Tue, 16 Jan 2024 03:58:34 -0500
Subject: [PATCH] [InstCombineCompares] Try to "strengthen" compares based on
 known bits.

For example, replace `icmp ugt %x, 14` with `icmp ugt %x, 15` when
it is known that the two least significant bits of `%x` are zero.
---
 .../InstCombine/InstCombineCompares.cpp       | 72 ++++++++++++++++
 llvm/test/Transforms/InstCombine/icmp.ll      | 82 +++++++++++++++----
 2 files changed, 138 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3875e59c3ede3b..30b421381de11c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6345,6 +6345,78 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
        (Op0Known.One.isNegative() && Op1Known.One.isNegative())))
     return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
 
+  // Try to "strengthen" the RHS of compare based on known bits.
+  // For example, replace `icmp ugt %x, 14` with `icmp ugt %x, 15` when
+  // it is known that the two least significant bits of `%x` are zero.
+  if (Op1Known.isConstant() && Op0Known.Zero.isMask()) {
+    APInt RHSConst = Op1Known.getConstant();
+    ConstantRange Op0PredRange =
+        ConstantRange::makeExactICmpRegion(Pred, RHSConst);
+    int KnownZeroMaskLength = BitWidth - Op0Known.Zero.countLeadingZeros();
+    if (KnownZeroMaskLength > 0) {
+      APInt PowOf2(BitWidth, 1 << KnownZeroMaskLength);
+      APInt Op0PredMin(BitWidth, 0);
+      APInt Op0PredMax(BitWidth, 0);
+      APInt Op0MinRefinedByKnownBits(BitWidth, 0);
+      APInt Op0MaxRefinedByKnownBits(BitWidth, 0);
+      APInt NewLower(BitWidth, 0);
+      APInt NewUpper(BitWidth, 0);
+      bool ImprovedLower = false;
+      bool ImprovedUpper = false;
+      if (I.isSigned()) {
+        Op0PredMin = Op0PredRange.getSignedMin();
+        Op0PredMax = Op0PredRange.getSignedMax();
+        // Compute the smallest number satisfying the known-bits constraint
+        // that is greater than or equal to Op0PredMin.
+        Op0MinRefinedByKnownBits =
+            PowOf2 *
+            APIntOps::RoundingSDiv(Op0PredMin, PowOf2, APInt::Rounding::UP);
+        // Compute the largest number satisfying the known-bits constraint
+        // that is less than or equal to Op0PredMax.
+        Op0MaxRefinedByKnownBits =
+            PowOf2 *
+            APIntOps::RoundingSDiv(Op0PredMax, PowOf2, APInt::Rounding::DOWN);
+        NewLower = APIntOps::smax(Op0MinRefinedByKnownBits, Op0PredMin);
+        NewUpper = APIntOps::smin(Op0MaxRefinedByKnownBits, Op0PredMax);
+        ImprovedLower = NewLower.sgt(Op0PredMin);
+        ImprovedUpper = NewUpper.slt(Op0PredMax);
+      } else {
+        Op0PredMin = Op0PredRange.getUnsignedMin();
+        Op0PredMax = Op0PredRange.getUnsignedMax();
+        Op0MinRefinedByKnownBits =
+            PowOf2 *
+            APIntOps::RoundingUDiv(Op0PredMin, PowOf2, APInt::Rounding::UP);
+        Op0MaxRefinedByKnownBits =
+            PowOf2 *
+            APIntOps::RoundingUDiv(Op0PredMax, PowOf2, APInt::Rounding::DOWN);
+        NewLower = APIntOps::umax(Op0MinRefinedByKnownBits, Op0PredMin);
+        NewUpper = APIntOps::umin(Op0MaxRefinedByKnownBits, Op0PredMax);
+        ImprovedLower = NewLower.ugt(Op0PredMin);
+        ImprovedUpper = NewUpper.ult(Op0PredMax);
+      }
+
+      // Non-strict inequalities should have been canonicalized to strict ones
+      // by now.
+      switch (Pred) {
+      default:
+        break;
+      case ICmpInst::ICMP_ULT:
+      case ICmpInst::ICMP_SLT: {
+        if (ImprovedUpper)
+          return new ICmpInst(Pred, Op0,
+                              ConstantInt::get(Op1->getType(), NewUpper + 1));
+        break;
+      }
+      case ICmpInst::ICMP_UGT:
+      case ICmpInst::ICMP_SGT: {
+        if (ImprovedLower)
+          return new ICmpInst(Pred, Op0,
+                              ConstantInt::get(Op1->getType(), NewLower - 1));
+        break;
+      }
+      }
+    }
+  }
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 1f554c7b60256c..339d66e8c2e437 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -1445,8 +1445,8 @@ define <2 x i1> @test70vec(<2 x i32> %X) {
 
 define i1 @icmp_sext16trunc(i32 %x) {
 ; CHECK-LABEL: @icmp_sext16trunc(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i16
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i16 [[TMP1]], 36
+; CHECK-NEXT:    [[SEXT1:%.*]] = shl i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[SEXT1]], 2293761
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %trunc = trunc i32 %x to i16
@@ -1457,8 +1457,8 @@ define i1 @icmp_sext16trunc(i32 %x) {
 
 define i1 @icmp_sext8trunc(i32 %x) {
 ; CHECK-LABEL: @icmp_sext8trunc(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[TMP1]], 36
+; CHECK-NEXT:    [[SEXT1:%.*]] = shl i32 [[X:%.*]], 24
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[SEXT1]], 587202561
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %trunc = trunc i32 %x to i8
@@ -1470,8 +1470,8 @@ define i1 @icmp_sext8trunc(i32 %x) {
 ; Vectors should fold the same way.
 define <2 x i1> @icmp_sext8trunc_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @icmp_sext8trunc_vec(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> [[X:%.*]] to <2 x i8>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <2 x i8> [[TMP1]], <i8 36, i8 36>
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i32> [[X:%.*]], <i32 24, i32 24>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <2 x i32> [[TMP1]], <i32 587202561, i32 587202561>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %trunc = trunc <2 x i32> %x to <2 x i8>
@@ -1482,8 +1482,8 @@ define <2 x i1> @icmp_sext8trunc_vec(<2 x i32> %x) {
 
 define i1 @icmp_shl16(i32 %x) {
 ; CHECK-LABEL: @icmp_shl16(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i16
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i16 [[TMP1]], 36
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[SHL]], 2293761
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %shl = shl i32 %x, 16
@@ -1496,7 +1496,7 @@ define i1 @icmp_shl16(i32 %x) {
 define i1 @icmp_shl17(i32 %x) {
 ; CHECK-LABEL: @icmp_shl17(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[X:%.*]], 17
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[SHL]], 2359296
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[SHL]], 2228225
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %shl = shl i32 %x, 17
@@ -1506,8 +1506,8 @@ define i1 @icmp_shl17(i32 %x) {
 
 define <2 x i1> @icmp_shl16_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @icmp_shl16_vec(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> [[X:%.*]] to <2 x i16>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <2 x i16> [[TMP1]], <i16 36, i16 36>
+; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i32> [[X:%.*]], <i32 16, i32 16>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <2 x i32> [[SHL]], <i32 2293761, i32 2293761>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %shl = shl <2 x i32> %x, <i32 16, i32 16>
@@ -1517,8 +1517,8 @@ define <2 x i1> @icmp_shl16_vec(<2 x i32> %x) {
 
 define i1 @icmp_shl24(i32 %x) {
 ; CHECK-LABEL: @icmp_shl24(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[TMP1]], 36
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[X:%.*]], 24
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[SHL]], 587202561
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %shl = shl i32 %x, 24
@@ -2154,7 +2154,7 @@ define i1 @icmp_ashr_and_overshift(i8 %X) {
 define i1 @icmp_and_ashr_neg_and_legal(i8 %x) {
 ; CHECK-LABEL: @icmp_and_ashr_neg_and_legal(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], -32
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[TMP1]], 16
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[TMP1]], 1
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %ashr = ashr i8 %x, 4
@@ -2180,7 +2180,7 @@ define i1 @icmp_and_ashr_mixed_and_shiftout(i8 %x) {
 define i1 @icmp_and_ashr_neg_cmp_slt_legal(i8 %x) {
 ; CHECK-LABEL: @icmp_and_ashr_neg_cmp_slt_legal(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], -32
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[TMP1]], -64
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[TMP1]], -95
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %ashr = ashr i8 %x, 4
@@ -2194,7 +2194,7 @@ define i1 @icmp_and_ashr_neg_cmp_slt_shiftout(i8 %x) {
 ; CHECK-LABEL: @icmp_and_ashr_neg_cmp_slt_shiftout(
 ; CHECK-NEXT:    [[ASHR:%.*]] = ashr i8 [[X:%.*]], 4
 ; CHECK-NEXT:    [[AND:%.*]] = and i8 [[ASHR]], -2
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[AND]], -68
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[AND]], -69
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %ashr = ashr i8 %x, 4
@@ -5138,3 +5138,53 @@ entry:
   %cmp = icmp eq i8 %add2, %add1
   ret i1 %cmp
 }
+
+define i1 @tighten_icmp_using_known_bits_ugt(i16 %a) {
+; CHECK-LABEL: @tighten_icmp_using_known_bits_ugt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i16 [[A:%.*]], 15
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp ugt i16 %and_, 14
+  ret i1 %cmp
+}
+
+define i1 @tighten_icmp_using_known_bits_ult(i16 %a) {
+; CHECK-LABEL: @tighten_icmp_using_known_bits_ult(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AND_:%.*]] = and i16 [[A:%.*]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[AND_]], 17
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp ult i16 %and_, 18
+  ret i1 %cmp
+}
+
+define i1 @tighten_icmp_using_known_bits_sgt(i16 %a) {
+; CHECK-LABEL: @tighten_icmp_using_known_bits_sgt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i16 [[A:%.*]], -1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %and_ = and i16 %a, 65520
+  %cmp = icmp sgt i16 %and_, -15
+  ret i1 %cmp
+}
+
+define i1 @tighten_icmp_using_known_bits_slt(i16 %a) {
+; CHECK-LABEL: @tighten_icmp_using_known_bits_slt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AND_:%.*]] = and i16 [[A:%.*]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i16 [[AND_]], -15
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp slt i16 %and_, -14
+  ret i1 %cmp
+}



More information about the llvm-commits mailing list