[llvm] [InstCombineCompare] Use known bits to insert assume intrinsics. (PR #96017)

Mikhail Gudim via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 21 10:08:58 PDT 2024


https://github.com/mgudim updated https://github.com/llvm/llvm-project/pull/96017

>From 7826fe77db6a359df754ffbd6d95e8d64f74f490 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at gmail.com>
Date: Tue, 18 Jun 2024 15:35:39 -0400
Subject: [PATCH 1/2] [InstCombineCompare] Use known bits to insert assume
 intrinsics.

If we have a compare instruction like `%cmp = icmp ult i16 %x, 14` and
the lower 2 bits of `%x` are known to be zero, then `%x` is a multiple
of 4. So if the compare is true, then `%x ult 13`, and if the compare is
false, then `%x uge 16`. When the compare feeds a conditional branch,
this patch inserts assume intrinsics at the start of the successor
blocks to express this knowledge.
---
 .../InstCombine/InstCombineCompares.cpp       | 121 ++++++++++
 .../Transforms/InstCombine/icmp-assume.ll     | 214 ++++++++++++++++++
 2 files changed, 335 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/icmp-assume.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 89193f8ff94b6..11261200d225e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6333,6 +6333,125 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
   return false;
 }
 
+/// Compute the values closest to \p Target that an operand with known-zero
+/// low bits can actually take.
+///
+/// If the low K bits of the operand are known zero
+/// (K = Known.countMinTrailingZeros()), the operand is a multiple of 2^K.
+/// On return:
+///   ClosestSmaller = largest multiple of 2^K that is <= Target
+///   ClosestBigger  = smallest multiple of 2^K that is >= Target
+/// Clearing the low K bits rounds toward minus infinity in both the unsigned
+/// and the two's-complement signed interpretation, so the same bit arithmetic
+/// serves both; \p IsSigned does not affect the result. ClosestBigger is
+/// computed modulo 2^BitWidth and wraps when the true value is not
+/// representable; a wrapped value compares below Target, and callers must
+/// reject it.
+///
+/// When no low bit is known zero, or every bit is, both outputs are set to
+/// \p Target, so that no tightening is derived. \p Target and \p Known are
+/// expected to be \p BitWidth bits wide.
+static void computeClosestIntsSatisfyingKnownBits(
+    const APInt &Target, const KnownBits &Known, unsigned BitWidth,
+    bool IsSigned, APInt &ClosestSmaller, APInt &ClosestBigger) {
+  ClosestSmaller = Target;
+  ClosestBigger = Target;
+  unsigned LowZeroBits = Known.countMinTrailingZeros();
+  if (LowZeroBits == 0 || LowZeroBits >= BitWidth)
+    return;
+
+  // Build the mask as an APInt: `1 << K` on an int is undefined once K
+  // reaches the width of int (e.g. for i64 operands).
+  APInt LowMask = APInt::getLowBitsSet(BitWidth, LowZeroBits);
+  ClosestSmaller = Target & ~LowMask;
+  // If Target is already a multiple of 2^K, it is its own ceiling.
+  if (!(Target & LowMask).isZero())
+    ClosestBigger =
+        ClosestSmaller + APInt::getOneBitSet(BitWidth, LowZeroBits);
+}
+
+/// Insert `call void @llvm.assume(i1 (icmp Pred LHS, RHS))` at the first
+/// insertion point of \p BB.
+///
+/// Does nothing if \p BB already contains an assume whose condition is an
+/// identical icmp (same predicate and same operands). This keeps repeated
+/// visits of the same compare from stacking duplicate assumes.
+static void insertAssumeICmp(BasicBlock *BB, ICmpInst::Predicate Pred,
+                             Value *LHS, Value *RHS, LLVMContext &Ctx) {
+  for (Instruction &Inst : *BB) {
+    auto *Assume = dyn_cast<AssumeInst>(&Inst);
+    if (!Assume)
+      continue;
+    auto *Cond = dyn_cast<ICmpInst>(Assume->getArgOperand(0));
+    if (Cond && Cond->getPredicate() == Pred && Cond->getOperand(0) == LHS &&
+        Cond->getOperand(1) == RHS)
+      return;
+  }
+
+  IRBuilder<> Builder(Ctx);
+  Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+  Value *Cmp = Builder.CreateICmp(Pred, LHS, RHS);
+  Builder.CreateAssumption(Cmp);
+}
+
+/// Use the known-zero low bits of operand 0 of \p I to derive tighter bounds
+/// on each edge of conditional branches that test \p I. Record those bounds
+/// as `llvm.assume` calls at the start of the successor blocks.
+///
+/// Example: for `%c = icmp ult i16 %x, 14` where the low 2 bits of %x are
+/// known zero (so %x is a multiple of 4), the true edge implies `%x ult 13`
+/// and the false edge implies `%x uge 16`.
+///
+/// Preconditions for doing anything:
+///   - operand 0 has integer type;
+///   - operand 1 is a known constant;
+///   - the known-zero bits of operand 0 form a non-empty low-bit mask;
+///   - the predicate is ult/slt/ugt/sgt.
+///
+/// A successor only receives an assume if its single incoming edge comes
+/// from the branch on \p I. A block with other predecessors, or one reached
+/// by both edges, can be entered without the compare having the matching
+/// outcome, and an assume there would be unsound.
+///
+/// \p Op0Known / \p Op1Known are the known bits of the two operands, and
+/// \p BitWidth is their width.
+///
+/// NOTE(review): the inserted instructions bypass InstCombine's worklist and
+/// are not reported as an IR change to the caller; consider returning a flag
+/// or using the InstCombiner's builder instead.
+static void tryToInsertAssumeBasedOnICmpAndKnownBits(ICmpInst &I,
+                                                     KnownBits Op0Known,
+                                                     KnownBits Op1Known,
+                                                     unsigned BitWidth) {
+  if (!BitWidth)
+    return;
+  Value *Op0 = I.getOperand(0);
+  Type *Ty = Op0->getType();
+  // ConstantInt::get below requires an integer type; bail out for anything
+  // else (e.g. pointer or vector compares).
+  if (!Ty->isIntegerTy())
+    return;
+  if (!(Op1Known.isConstant() && Op0Known.Zero.isMask()))
+    return;
+
+  const ICmpInst::Predicate Pred = I.getPredicate();
+  const bool IsSigned = I.isSigned();
+  const APInt &RHSConst = Op1Known.getConstant();
+
+  APInt ClosestSmaller(BitWidth, 0);
+  APInt ClosestBigger(BitWidth, 0);
+  computeClosestIntsSatisfyingKnownBits(RHSConst, Op0Known, BitWidth, IsSigned,
+                                        ClosestSmaller, ClosestBigger);
+
+  auto LessThan = [IsSigned](const APInt &L, const APInt &R) {
+    return IsSigned ? L.slt(R) : L.ult(R);
+  };
+  // RHSConst - 1 / RHSConst + 1 must not wrap. At the min/max constant the
+  // corresponding true edge is dead anyway, so nothing is lost.
+  const bool RHSIsMin =
+      IsSigned ? RHSConst.isMinSignedValue() : RHSConst.isMinValue();
+  const bool RHSIsMax =
+      IsSigned ? RHSConst.isMaxSignedValue() : RHSConst.isMaxValue();
+
+  const ICmpInst::Predicate AssumePredT = Pred;
+  const ICmpInst::Predicate AssumePredF = ICmpInst::getInversePredicate(Pred);
+  APInt AssumeRHSConstantT(BitWidth, 0);
+  APInt AssumeRHSConstantF(BitWidth, 0);
+  bool CanImproveT = false;
+  bool CanImproveF = false;
+
+  switch (Pred) {
+  default:
+    return;
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_SLT:
+    // True edge: Op0 < C, and Op0 is a multiple of 2^K, so
+    // Op0 <= ClosestSmaller (i.e. Op0 < ClosestSmaller + 1). Only emit this
+    // when it is strictly tighter than Op0 < C.
+    if (!RHSIsMin && LessThan(ClosestSmaller, RHSConst - 1)) {
+      CanImproveT = true;
+      AssumeRHSConstantT = ClosestSmaller + 1;
+    }
+    // False edge: Op0 >= C, hence Op0 >= ClosestBigger. A wrapped
+    // ClosestBigger compares below C and is rejected by this check.
+    if (LessThan(RHSConst, ClosestBigger)) {
+      CanImproveF = true;
+      AssumeRHSConstantF = ClosestBigger;
+    }
+    break;
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_SGT:
+    // True edge: Op0 > C, hence Op0 >= ClosestBigger, expressed as
+    // Op0 > ClosestBigger - 1. A wrapped ClosestBigger is rejected here.
+    if (!RHSIsMax && LessThan(RHSConst + 1, ClosestBigger)) {
+      CanImproveT = true;
+      AssumeRHSConstantT = ClosestBigger - 1;
+    }
+    // False edge: Op0 <= C, hence Op0 <= ClosestSmaller.
+    if (LessThan(ClosestSmaller, RHSConst)) {
+      CanImproveF = true;
+      AssumeRHSConstantF = ClosestSmaller;
+    }
+    break;
+  }
+  if (!CanImproveT && !CanImproveF)
+    return;
+
+  // Collect successors that are entered only through the true / false edge of
+  // a conditional branch on I.
+  SmallVector<BasicBlock *> TBBs;
+  SmallVector<BasicBlock *> FBBs;
+  for (User *U : I.users()) {
+    auto *Br = dyn_cast<BranchInst>(U);
+    if (!Br || Br->isUnconditional())
+      continue;
+    BasicBlock *BrBB = Br->getParent();
+    BasicBlock *TrueBB = Br->getSuccessor(0);
+    BasicBlock *FalseBB = Br->getSuccessor(1);
+    // Both edges lead to the same block: neither outcome is known there.
+    if (TrueBB == FalseBB)
+      continue;
+    if (TrueBB->getSinglePredecessor() == BrBB)
+      TBBs.push_back(TrueBB);
+    if (FalseBB->getSinglePredecessor() == BrBB)
+      FBBs.push_back(FalseBB);
+  }
+
+  LLVMContext &Ctx = I.getContext();
+  if (CanImproveT) {
+    Constant *RHS = ConstantInt::get(Ty, AssumeRHSConstantT);
+    for (BasicBlock *TBB : TBBs)
+      insertAssumeICmp(TBB, AssumePredT, Op0, RHS, Ctx);
+  }
+  if (CanImproveF) {
+    Constant *RHS = ConstantInt::get(Ty, AssumeRHSConstantF);
+    for (BasicBlock *FBB : FBBs)
+      insertAssumeICmp(FBB, AssumePredF, Op0, RHS, Ctx);
+  }
+}
+
 /// Try to fold the comparison based on range information we can get by checking
 /// whether bits are known to be zero or one in the inputs.
 Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
@@ -6590,6 +6709,8 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
        (Op0Known.One.isNegative() && Op1Known.One.isNegative())))
     return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
 
+  tryToInsertAssumeBasedOnICmpAndKnownBits(I, Op0Known, Op1Known, BitWidth);
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/icmp-assume.ll b/llvm/test/Transforms/InstCombine/icmp-assume.ll
new file mode 100644
index 0000000000000..f07b272e14723
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-assume.ll
@@ -0,0 +1,214 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+target datalayout = "e-p:64:64:64-p1:16:16:16-p2:32:32:32-p3:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"
+
+; %and_ has its low 2 bits cleared, so it is a multiple of 4. For `ult 14`
+; the true edge gets an assume of `ult 13` and the false edge `uge 16`.
+define void @generate_assume_ult(i16 %a) {
+; CHECK-LABEL: define void @generate_assume_ult(
+; CHECK-SAME: i16 [[A:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[AND_:%.*]] = and i16 [[A]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[AND_]], 14
+; CHECK-NEXT:    br i1 [[CMP]], label %[[T:.*]], label %[[F:.*]]
+; CHECK:       [[T]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp ult i16 [[AND_]], 13
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT:    ret void
+; CHECK:       [[F]]:
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp uge i16 [[AND_]], 16
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP1]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp ult i16 %and_, 14
+  br i1 %cmp, label %t, label %f
+
+t:
+  ret void
+
+f:
+  ret void
+}
+
+; Signed variant of @generate_assume_ult: true edge `slt 13`, false edge
+; `sge 16`.
+; NOTE(review): only non-negative constants are covered; a negative signed
+; constant would exercise the signed rounding path.
+define void @generate_assume_slt(i16 %a) {
+; CHECK-LABEL: define void @generate_assume_slt(
+; CHECK-SAME: i16 [[A:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[AND_:%.*]] = and i16 [[A]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i16 [[AND_]], 14
+; CHECK-NEXT:    br i1 [[CMP]], label %[[T:.*]], label %[[F:.*]]
+; CHECK:       [[T]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp slt i16 [[AND_]], 13
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT:    ret void
+; CHECK:       [[F]]:
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sge i16 [[AND_]], 16
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP1]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp slt i16 %and_, 14
+  br i1 %cmp, label %t, label %f
+
+t:
+  ret void
+
+f:
+  ret void
+}
+
+; %and_ is a multiple of 4. For `ugt 14` the true edge gets `ugt 15`
+; (i.e. at least 16) and the false edge gets `ule 12`.
+define void @generate_assume_ugt(i16 %a) {
+; CHECK-LABEL: define void @generate_assume_ugt(
+; CHECK-SAME: i16 [[A:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[AND_:%.*]] = and i16 [[A]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i16 [[AND_]], 14
+; CHECK-NEXT:    br i1 [[CMP]], label %[[T:.*]], label %[[F:.*]]
+; CHECK:       [[T]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp ugt i16 [[AND_]], 15
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT:    ret void
+; CHECK:       [[F]]:
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ule i16 [[AND_]], 12
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP1]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp ugt i16 %and_, 14
+  br i1 %cmp, label %t, label %f
+
+t:
+  ret void
+
+f:
+  ret void
+}
+
+; Signed variant of @generate_assume_ugt: true edge `sgt 15`, false edge
+; `sle 12`.
+define void @generate_assume_sgt(i16 %a) {
+; CHECK-LABEL: define void @generate_assume_sgt(
+; CHECK-SAME: i16 [[A:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[AND_:%.*]] = and i16 [[A]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i16 [[AND_]], 14
+; CHECK-NEXT:    br i1 [[CMP]], label %[[T:.*]], label %[[F:.*]]
+; CHECK:       [[T]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i16 [[AND_]], 15
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT:    ret void
+; CHECK:       [[F]]:
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sle i16 [[AND_]], 12
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP1]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp sgt i16 %and_, 14
+  br i1 %cmp, label %t, label %f
+
+t:
+  ret void
+
+f:
+  ret void
+}
+
+; `ult 13` already pins %and_ to at most 12, the largest multiple of 4 below
+; 13, so the true edge gets no assume. The false edge still gets `uge 16`.
+define void @dont_generate_assume_t_ult(i16 %a) {
+; CHECK-LABEL: define void @dont_generate_assume_t_ult(
+; CHECK-SAME: i16 [[A:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[AND_:%.*]] = and i16 [[A]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[AND_]], 13
+; CHECK-NEXT:    br i1 [[CMP]], label %[[T:.*]], label %[[F:.*]]
+; CHECK:       [[T]]:
+; CHECK-NEXT:    ret void
+; CHECK:       [[F]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp uge i16 [[AND_]], 16
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp ult i16 %and_, 13
+  br i1 %cmp, label %t, label %f
+
+t:
+  ret void
+
+f:
+  ret void
+}
+
+; Signed variant of @dont_generate_assume_t_ult: no assume on the true edge,
+; `sge 16` on the false edge.
+define void @dont_generate_assume_t_slt(i16 %a) {
+; CHECK-LABEL: define void @dont_generate_assume_t_slt(
+; CHECK-SAME: i16 [[A:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[AND_:%.*]] = and i16 [[A]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i16 [[AND_]], 13
+; CHECK-NEXT:    br i1 [[CMP]], label %[[T:.*]], label %[[F:.*]]
+; CHECK:       [[T]]:
+; CHECK-NEXT:    ret void
+; CHECK:       [[F]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp sge i16 [[AND_]], 16
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp slt i16 %and_, 13
+  br i1 %cmp, label %t, label %f
+
+t:
+  ret void
+
+f:
+  ret void
+}
+
+; %cmp is branched on twice; each branch's true/false successor gets its own
+; assume. %bb (two predecessors) is not a direct successor of a branch on
+; %cmp and receives nothing.
+; NOTE(review): no test covers a branch successor that has several
+; predecessors, or a branch whose two successors coincide; no assume should
+; be inserted in either case.
+define void @multiple_branches_ult(i16 %a) {
+; CHECK-LABEL: define void @multiple_branches_ult(
+; CHECK-SAME: i16 [[A:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[AND_:%.*]] = and i16 [[A]], -4
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[AND_]], 14
+; CHECK-NEXT:    br i1 [[CMP]], label %[[T0:.*]], label %[[F0:.*]]
+; CHECK:       [[T0]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp ult i16 [[AND_]], 13
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT:    br label %[[BB:.*]]
+; CHECK:       [[F0]]:
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp uge i16 [[AND_]], 16
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP1]])
+; CHECK-NEXT:    br label %[[BB]]
+; CHECK:       [[BB]]:
+; CHECK-NEXT:    br i1 [[CMP]], label %[[T1:.*]], label %[[F1:.*]]
+; CHECK:       [[T1]]:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i16 [[AND_]], 13
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP2]])
+; CHECK-NEXT:    ret void
+; CHECK:       [[F1]]:
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp uge i16 [[AND_]], 16
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP3]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %and_ = and i16 %a, 65532
+  %cmp = icmp ult i16 %and_, 14
+  br i1 %cmp, label %t0, label %f0
+
+t0:
+  br label %bb
+
+f0:
+  br label %bb
+
+bb:
+  br i1 %cmp, label %t1, label %f1
+
+t1:
+  ret void
+f1:
+  ret void
+}

>From 2f77f23d5827ed946b059bdef0959aa894983e4f Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at gmail.com>
Date: Fri, 21 Jun 2024 13:08:24 -0400
Subject: [PATCH 2/2] Add a command-line option controlling insertion of assume intrinsics (off by default)

---
 llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 7 ++++++-
 llvm/test/Transforms/InstCombine/icmp-assume.ll         | 2 +-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 11261200d225e..9ec13473753b1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -38,6 +38,10 @@ using namespace PatternMatch;
 // How many times is a select replaced by one of its operands?
 STATISTIC(NumSel, "Number of select opts");
 
+/// Developer flag gating tryToInsertAssumeBasedOnICmpAndKnownBits. It
+/// defaults to true, so assume insertion is off by default; pass
+/// -instcombine-disable-insert-assume-icmp=false to enable it.
+static cl::opt<bool> DisableInsertAssumeICmp(
+    "instcombine-disable-insert-assume-icmp", cl::init(true), cl::Hidden,
+    cl::desc("Disable insertion of assume intrinsics derived from known bits "
+             "and icmp"));
 
 /// Compute Result = In1+In2, returning true if the result overflowed for this
 /// type.
@@ -6709,7 +6713,8 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
        (Op0Known.One.isNegative() && Op1Known.One.isNegative())))
     return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
 
-  tryToInsertAssumeBasedOnICmpAndKnownBits(I, Op0Known, Op1Known, BitWidth);
+  if (!DisableInsertAssumeICmp)
+    tryToInsertAssumeBasedOnICmpAndKnownBits(I, Op0Known, Op1Known, BitWidth);
 
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/icmp-assume.ll b/llvm/test/Transforms/InstCombine/icmp-assume.ll
index f07b272e14723..9c5adebb3dc15 100644
--- a/llvm/test/Transforms/InstCombine/icmp-assume.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-assume.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -instcombine-disable-insert-assume-icmp=false -S | FileCheck %s
 
 target datalayout = "e-p:64:64:64-p1:16:16:16-p2:32:32:32-p3:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"
 



More information about the llvm-commits mailing list