[llvm] [InstCombineCompare] Use known bits to insert assume intrinsics. (PR #96017)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 21 15:22:59 PDT 2024


================
@@ -6333,6 +6337,125 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
   return false;
 }
 
+/// Find the values nearest to \p Target that are compatible with the low
+/// known-zero bits recorded in \p Known.
+///
+/// This assumes Known.Zero is a low-bit mask (the caller in this patch checks
+/// Known.Zero.isMask()). If its K lowest bits are set, the described value is
+/// a multiple of 2^K, and this function computes:
+///   - \p ClosestSmaller: the largest multiple of 2^K that is <= Target;
+///   - \p ClosestBigger:  the smallest multiple of 2^K that is >= Target.
+///
+/// Clearing the low K bits rounds towards negative infinity in two's
+/// complement, so the same computation is valid for both the signed and the
+/// unsigned interpretation of \p Target; \p IsSigned is kept only for
+/// interface compatibility.
+///
+/// \p ClosestBigger is computed modulo 2^BitWidth and may wrap around when no
+/// representable multiple lies above \p Target; callers must compare it
+/// against \p Target before relying on it.
+///
+/// If no low bit is known to be zero, both outputs are left untouched.
+static void computeClosestIntsSatisfyingKnownBits(
+    APInt Target, KnownBits &Known, unsigned BitWidth, bool IsSigned,
+    APInt &ClosestSmaller, APInt &ClosestBigger) {
+  (void)IsSigned;
+
+  // For a low-bit mask, the position just past its highest set bit is the
+  // number of known-zero low bits.
+  unsigned KnownZeroMaskLength = BitWidth - Known.Zero.countLeadingZeros();
+  if (KnownZeroMaskLength == 0)
+    return;
+
+  // Every bit is known to be zero, so the only possible value is 0. Handle
+  // this separately: 2^BitWidth is not representable in BitWidth bits.
+  if (KnownZeroMaskLength >= BitWidth) {
+    ClosestSmaller = APInt::getZero(BitWidth);
+    ClosestBigger = APInt::getZero(BitWidth);
+    return;
+  }
+
+  // Build the mask and the step as APInts rather than with `1 << K` on a
+  // plain int, which overflows once K reaches the width of int.
+  APInt LowMask = APInt::getLowBitsSet(BitWidth, KnownZeroMaskLength);
+  ClosestSmaller = Target & ~LowMask;
+  ClosestBigger = ClosestSmaller;
+  // An already aligned Target is its own closest value in both directions;
+  // otherwise step up to the next multiple.
+  if (ClosestSmaller != Target)
+    ClosestBigger += APInt::getOneBitSet(BitWidth, KnownZeroMaskLength);
+}
+
+/// Emit `icmp Pred LHS, RHS` followed by a call to llvm.assume on its result,
+/// both placed at the first insertion point of \p BB.
+///
+/// \param BB   Block that receives the new instructions.
+/// \param Pred Predicate of the compare being assumed.
+/// \param LHS  Left-hand operand of the compare.
+/// \param RHS  Right-hand operand of the compare.
+/// \param Ctx  Context used to construct the IRBuilder.
+static void insertAssumeICmp(BasicBlock *BB, ICmpInst::Predicate Pred,
+                             Value *LHS, Value *RHS, LLVMContext &Ctx) {
+  IRBuilder<> AssumeBuilder(Ctx);
+  AssumeBuilder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+  // Create the condition first so that it precedes the assume consuming it.
+  Value *Cond = AssumeBuilder.CreateICmp(Pred, LHS, RHS);
+  AssumeBuilder.CreateAssumption(Cond);
+}
+
+/// Use the known low zero bits of the LHS of \p I to insert tighter
+/// llvm.assume bounds into the successors of branches controlled by \p I.
+///
+/// Handles `icmp {ult,slt,ugt,sgt} Op0, C` where C is fully known
+/// (\p Op1Known) and the known-zero bits of Op0 (\p Op0Known) form a run
+/// starting at bit 0, i.e. Op0 is a multiple of some power of two P.
+/// Example with P == 8: on the true edge of `Op0 u< 13` the value is at most
+/// 8, so `Op0 u< 9` may be assumed; on the false edge it is at least 16, so
+/// `Op0 u>= 16` may be assumed.
+///
+/// \param I         The compare to inspect; assumes are phrased on operand 0.
+/// \param Op0Known  Known bits of operand 0.
+/// \param Op1Known  Known bits of operand 1; must describe a constant.
+/// \param BitWidth  Bit width of the compared values; 0 disables the
+///                  transform.
+///
+/// NOTE(review): nothing here prevents a duplicate assume from being inserted
+/// if this runs again on the same compare — confirm the caller guards against
+/// that.
+static void tryToInsertAssumeBasedOnICmpAndKnownBits(ICmpInst &I,
+                                                     KnownBits Op0Known,
+                                                     KnownBits Op1Known,
+                                                     unsigned BitWidth) {
+  if (!BitWidth)
+    return;
+  if (!(Op1Known.isConstant() && Op0Known.Zero.isMask()))
+    return;
+
+  // Collect the successors of every conditional branch on this compare. A
+  // fact derived from the branch condition only holds in a successor if that
+  // successor can be entered exclusively through this one edge, so require
+  // the branch's block to be its single predecessor. This also rejects
+  // branches whose two edges lead to the same block. A block that is its own
+  // single predecessor is skipped as well, since Op0 might be defined after
+  // the insertion point there.
+  SmallVector<BasicBlock *> TBBs;
+  SmallVector<BasicBlock *> FBBs;
+  for (Use &U : I.uses()) {
+    auto *BrUse = dyn_cast<BranchInst>(U.getUser());
+    if (!BrUse || BrUse->isUnconditional())
+      continue;
+    BasicBlock *BrBB = BrUse->getParent();
+    BasicBlock *TrueSucc = BrUse->getSuccessor(0);
+    BasicBlock *FalseSucc = BrUse->getSuccessor(1);
+    if (TrueSucc != BrBB && TrueSucc->getSinglePredecessor() == BrBB)
+      TBBs.push_back(TrueSucc);
+    if (FalseSucc != BrBB && FalseSucc->getSinglePredecessor() == BrBB)
+      FBBs.push_back(FalseSucc);
+  }
+  if (TBBs.empty() && FBBs.empty())
+    return;
+
+  APInt RHSConst = Op1Known.getConstant();
+
+  // Nearest values around the constant that Op0 could actually take.
+  bool IsSigned = I.isSigned();
+  APInt ClosestSmaller(BitWidth, 0);
+  APInt ClosestBigger(BitWidth, 0);
+  computeClosestIntsSatisfyingKnownBits(RHSConst, Op0Known, BitWidth, IsSigned,
+                                        ClosestSmaller, ClosestBigger);
+
+  // The true edge keeps the original predicate and the false edge uses its
+  // inverse; only the constant on the right-hand side is tightened.
+  ICmpInst::Predicate AssumePredT = I.getPredicate();
+  ICmpInst::Predicate AssumePredF = ICmpInst::getInversePredicate(AssumePredT);
+  APInt AssumeRHSConstantT(BitWidth, 0);
+  APInt AssumeRHSConstantF(BitWidth, 0);
+  bool CanImproveT = false;
+  bool CanImproveF = false;
+
+  // Strict less-than in the signedness of the compare.
+  auto IsLessThan = [IsSigned](const APInt &LHS, const APInt &RHS) {
+    return IsSigned ? LHS.slt(RHS) : LHS.ult(RHS);
+  };
+  switch (AssumePredT) {
+  default:
+    // Other predicates are not handled.
+    break;
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_SLT: {
+    // True edge: Op0 < C implies Op0 <= ClosestSmaller, which is expressed
+    // as Op0 < ClosestSmaller + 1. Only worthwhile if that bound is tighter.
+    if (IsLessThan(ClosestSmaller, RHSConst - 1)) {
+      CanImproveT = true;
+      AssumeRHSConstantT = ClosestSmaller + 1;
+    }
+    // False edge: Op0 >= C implies Op0 >= ClosestBigger. The comparison also
+    // filters out a ClosestBigger that wrapped around.
+    if (IsLessThan(RHSConst, ClosestBigger)) {
+      CanImproveF = true;
+      AssumeRHSConstantF = ClosestBigger;
+    }
+    break;
+  }
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_SGT: {
+    // True edge: Op0 > C implies Op0 >= ClosestBigger, which is expressed
+    // as Op0 > ClosestBigger - 1.
+    if (IsLessThan(RHSConst + 1, ClosestBigger)) {
+      CanImproveT = true;
+      AssumeRHSConstantT = ClosestBigger - 1;
+    }
+    // False edge: Op0 <= C implies Op0 <= ClosestSmaller.
+    if (IsLessThan(ClosestSmaller, RHSConst)) {
+      CanImproveF = true;
+      AssumeRHSConstantF = ClosestSmaller;
+    }
+    break;
+  }
+  }
+
+  Value *Op0 = I.getOperand(0);
+  Type *Ty = Op0->getType();
+  LLVMContext &Ctx = I.getContext();
+  if (CanImproveT) {
+    Constant *RHS = ConstantInt::get(Ty, AssumeRHSConstantT);
+    for (BasicBlock *TBB : TBBs)
+      insertAssumeICmp(TBB, AssumePredT, Op0, RHS, Ctx);
+  }
+  if (CanImproveF) {
+    Constant *RHS = ConstantInt::get(Ty, AssumeRHSConstantF);
+    for (BasicBlock *FBB : FBBs)
+      insertAssumeICmp(FBB, AssumePredF, Op0, RHS, Ctx);
+  }
+}
----------------
goldsteinn wrote:

This code is in desperate need of some comments.

https://github.com/llvm/llvm-project/pull/96017


More information about the llvm-commits mailing list