[llvm] [InstCombineCompares] Try to "strengthen" compares based on known bits. (PR #79405)
Mikhail Gudim via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 24 20:32:38 PST 2024
https://github.com/mgudim created https://github.com/llvm/llvm-project/pull/79405
For example, replace `icmp ugt %x, 14` with `icmp ugt %x, 15` when it is known that the two least significant bits of `%x` are zero.
>From df2bae8822ebf9db11e9dcae16c9bb7c760476f1 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at gmail.com>
Date: Tue, 16 Jan 2024 03:58:34 -0500
Subject: [PATCH] [InstCombineCompares] Try to "strengthen" compares based on
known bits.
For example, replace `icmp ugt %x, 14` with `icmp ugt %x, 15` when
it is known that the two least significant bits of `%x` are zero.
---
.../InstCombine/InstCombineCompares.cpp | 72 ++++++++++++++++
llvm/test/Transforms/InstCombine/icmp.ll | 82 +++++++++++++++----
2 files changed, 138 insertions(+), 16 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3875e59c3ede3b..30b421381de11c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6345,6 +6345,78 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
(Op0Known.One.isNegative() && Op1Known.One.isNegative())))
return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
+ // Try to "strengthen" the RHS of compare based on known bits.
+ // For example, replace `icmp ugt %x, 14` with `icmp ugt %x, 15` when
+ // it is known that the two least significant bits of `%x` are zero.
+ if (Op1Known.isConstant() && Op0Known.Zero.isMask()) {
+ APInt RHSConst = Op1Known.getConstant();
+ ConstantRange Op0PredRange =
+ ConstantRange::makeExactICmpRegion(Pred, RHSConst);
+ int KnownZeroMaskLength = BitWidth - Op0Known.Zero.countLeadingZeros();
+ if (KnownZeroMaskLength > 0) {
+ APInt PowOf2(BitWidth, 1 << KnownZeroMaskLength);
+ APInt Op0PredMin(BitWidth, 0);
+ APInt Op0PredMax(BitWidth, 0);
+ APInt Op0MinRefinedByKnownBits(BitWidth, 0);
+ APInt Op0MaxRefinedByKnownBits(BitWidth, 0);
+ APInt NewLower(BitWidth, 0);
+ APInt NewUpper(BitWidth, 0);
+ bool ImprovedLower = false;
+ bool ImprovedUpper = false;
+ if (I.isSigned()) {
+ Op0PredMin = Op0PredRange.getSignedMin();
+ Op0PredMax = Op0PredRange.getSignedMax();
+ // Compute the smallest number satisfying the known-bits constraint
+ // which is greater than or equal to Op0PredMin.
+ Op0MinRefinedByKnownBits =
+ PowOf2 *
+ APIntOps::RoundingSDiv(Op0PredMin, PowOf2, APInt::Rounding::UP);
+ // Compute the largest number satisfying the known-bits constraint
+ // which is less than or equal to Op0PredMax.
+ Op0MaxRefinedByKnownBits =
+ PowOf2 *
+ APIntOps::RoundingSDiv(Op0PredMax, PowOf2, APInt::Rounding::DOWN);
+ NewLower = APIntOps::smax(Op0MinRefinedByKnownBits, Op0PredMin);
+ NewUpper = APIntOps::smin(Op0MaxRefinedByKnownBits, Op0PredMax);
+ ImprovedLower = NewLower.sgt(Op0PredMin);
+ ImprovedUpper = NewUpper.slt(Op0PredMax);
+ } else {
+ Op0PredMin = Op0PredRange.getUnsignedMin();
+ Op0PredMax = Op0PredRange.getUnsignedMax();
+ Op0MinRefinedByKnownBits =
+ PowOf2 *
+ APIntOps::RoundingUDiv(Op0PredMin, PowOf2, APInt::Rounding::UP);
+ Op0MaxRefinedByKnownBits =
+ PowOf2 *
+ APIntOps::RoundingUDiv(Op0PredMax, PowOf2, APInt::Rounding::DOWN);
+ NewLower = APIntOps::umax(Op0MinRefinedByKnownBits, Op0PredMin);
+ NewUpper = APIntOps::umin(Op0MaxRefinedByKnownBits, Op0PredMax);
+ ImprovedLower = NewLower.ugt(Op0PredMin);
+ ImprovedUpper = NewUpper.ult(Op0PredMax);
+ }
+
+ // Non-strict inequalities should have been canonicalized to strict ones
+ // by now.
+ switch (Pred) {
+ default:
+ break;
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_SLT: {
+ if (ImprovedUpper)
+ return new ICmpInst(Pred, Op0,
+ ConstantInt::get(Op1->getType(), NewUpper + 1));
+ break;
+ }
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_SGT: {
+ if (ImprovedLower)
+ return new ICmpInst(Pred, Op0,
+ ConstantInt::get(Op1->getType(), NewLower - 1));
+ break;
+ }
+ }
+ }
+ }
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 1f554c7b60256c..339d66e8c2e437 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -1445,8 +1445,8 @@ define <2 x i1> @test70vec(<2 x i32> %X) {
define i1 @icmp_sext16trunc(i32 %x) {
; CHECK-LABEL: @icmp_sext16trunc(
-; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i16
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[TMP1]], 36
+; CHECK-NEXT: [[SEXT1:%.*]] = shl i32 [[X:%.*]], 16
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SEXT1]], 2293761
; CHECK-NEXT: ret i1 [[CMP]]
;
%trunc = trunc i32 %x to i16
@@ -1457,8 +1457,8 @@ define i1 @icmp_sext16trunc(i32 %x) {
define i1 @icmp_sext8trunc(i32 %x) {
; CHECK-LABEL: @icmp_sext8trunc(
-; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[TMP1]], 36
+; CHECK-NEXT: [[SEXT1:%.*]] = shl i32 [[X:%.*]], 24
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SEXT1]], 587202561
; CHECK-NEXT: ret i1 [[CMP]]
;
%trunc = trunc i32 %x to i8
@@ -1470,8 +1470,8 @@ define i1 @icmp_sext8trunc(i32 %x) {
; Vectors should fold the same way.
define <2 x i1> @icmp_sext8trunc_vec(<2 x i32> %x) {
; CHECK-LABEL: @icmp_sext8trunc_vec(
-; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i32> [[X:%.*]] to <2 x i8>
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt <2 x i8> [[TMP1]], <i8 36, i8 36>
+; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i32> [[X:%.*]], <i32 24, i32 24>
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt <2 x i32> [[TMP1]], <i32 587202561, i32 587202561>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%trunc = trunc <2 x i32> %x to <2 x i8>
@@ -1482,8 +1482,8 @@ define <2 x i1> @icmp_sext8trunc_vec(<2 x i32> %x) {
define i1 @icmp_shl16(i32 %x) {
; CHECK-LABEL: @icmp_shl16(
-; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i16
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[TMP1]], 36
+; CHECK-NEXT: [[SHL:%.*]] = shl i32 [[X:%.*]], 16
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SHL]], 2293761
; CHECK-NEXT: ret i1 [[CMP]]
;
%shl = shl i32 %x, 16
@@ -1496,7 +1496,7 @@ define i1 @icmp_shl16(i32 %x) {
define i1 @icmp_shl17(i32 %x) {
; CHECK-LABEL: @icmp_shl17(
; CHECK-NEXT: [[SHL:%.*]] = shl i32 [[X:%.*]], 17
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SHL]], 2359296
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SHL]], 2228225
; CHECK-NEXT: ret i1 [[CMP]]
;
%shl = shl i32 %x, 17
@@ -1506,8 +1506,8 @@ define i1 @icmp_shl17(i32 %x) {
define <2 x i1> @icmp_shl16_vec(<2 x i32> %x) {
; CHECK-LABEL: @icmp_shl16_vec(
-; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i32> [[X:%.*]] to <2 x i16>
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt <2 x i16> [[TMP1]], <i16 36, i16 36>
+; CHECK-NEXT: [[SHL:%.*]] = shl <2 x i32> [[X:%.*]], <i32 16, i32 16>
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt <2 x i32> [[SHL]], <i32 2293761, i32 2293761>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%shl = shl <2 x i32> %x, <i32 16, i32 16>
@@ -1517,8 +1517,8 @@ define <2 x i1> @icmp_shl16_vec(<2 x i32> %x) {
define i1 @icmp_shl24(i32 %x) {
; CHECK-LABEL: @icmp_shl24(
-; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[TMP1]], 36
+; CHECK-NEXT: [[SHL:%.*]] = shl i32 [[X:%.*]], 24
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SHL]], 587202561
; CHECK-NEXT: ret i1 [[CMP]]
;
%shl = shl i32 %x, 24
@@ -2154,7 +2154,7 @@ define i1 @icmp_ashr_and_overshift(i8 %X) {
define i1 @icmp_and_ashr_neg_and_legal(i8 %x) {
; CHECK-LABEL: @icmp_and_ashr_neg_and_legal(
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -32
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[TMP1]], 16
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[TMP1]], 1
; CHECK-NEXT: ret i1 [[CMP]]
;
%ashr = ashr i8 %x, 4
@@ -2180,7 +2180,7 @@ define i1 @icmp_and_ashr_mixed_and_shiftout(i8 %x) {
define i1 @icmp_and_ashr_neg_cmp_slt_legal(i8 %x) {
; CHECK-LABEL: @icmp_and_ashr_neg_cmp_slt_legal(
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -32
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[TMP1]], -64
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[TMP1]], -95
; CHECK-NEXT: ret i1 [[CMP]]
;
%ashr = ashr i8 %x, 4
@@ -2194,7 +2194,7 @@ define i1 @icmp_and_ashr_neg_cmp_slt_shiftout(i8 %x) {
; CHECK-LABEL: @icmp_and_ashr_neg_cmp_slt_shiftout(
; CHECK-NEXT: [[ASHR:%.*]] = ashr i8 [[X:%.*]], 4
; CHECK-NEXT: [[AND:%.*]] = and i8 [[ASHR]], -2
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[AND]], -68
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[AND]], -69
; CHECK-NEXT: ret i1 [[CMP]]
;
%ashr = ashr i8 %x, 4
@@ -5138,3 +5138,53 @@ entry:
%cmp = icmp eq i8 %add2, %add1
ret i1 %cmp
}
+
+define i1 @tighten_icmp_using_known_bits_ugt(i16 %a) {
+; CHECK-LABEL: @tighten_icmp_using_known_bits_ugt(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i16 [[A:%.*]], 15
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %and_ = and i16 %a, 65532
+ %cmp = icmp ugt i16 %and_, 14
+ ret i1 %cmp
+}
+
+define i1 @tighten_icmp_using_known_bits_ult(i16 %a) {
+; CHECK-LABEL: @tighten_icmp_using_known_bits_ult(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[AND_:%.*]] = and i16 [[A:%.*]], -4
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[AND_]], 17
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %and_ = and i16 %a, 65532
+ %cmp = icmp ult i16 %and_, 18
+ ret i1 %cmp
+}
+
+define i1 @tighten_icmp_using_known_bits_sgt(i16 %a) {
+; CHECK-LABEL: @tighten_icmp_using_known_bits_sgt(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i16 [[A:%.*]], -1
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %and_ = and i16 %a, 65520
+ %cmp = icmp sgt i16 %and_, -15
+ ret i1 %cmp
+}
+
+define i1 @tighten_icmp_using_known_bits_slt(i16 %a) {
+; CHECK-LABEL: @tighten_icmp_using_known_bits_slt(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[AND_:%.*]] = and i16 [[A:%.*]], -4
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[AND_]], -15
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %and_ = and i16 %a, 65532
+ %cmp = icmp slt i16 %and_, -14
+ ret i1 %cmp
+}
More information about the llvm-commits
mailing list