[llvm] [InstCombineCompare] Use known bits to insert assume intrinsics. (PR #96017)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 21 15:22:59 PDT 2024
================
@@ -6333,6 +6337,125 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
return false;
}
+/// Finds the integers nearest to \p Target that are compatible with the
+/// trailing known-zero bits recorded in \p Known.
+///
+/// If the low K bits of a value are known to be zero, the value is a multiple
+/// of 2^K. This helper rounds \p Target down and up to the nearest such
+/// multiple.
+///
+/// \param Target         The constant to round.
+/// \param Known          Known bits of the value being constrained; only the
+///                       run of known-zero bits starting at bit 0 is used.
+/// \param BitWidth       Bit width used for the mask; must match the width of
+///                       \p Target.
+/// \param IsSigned       Kept for interface compatibility. Clearing low bits
+///                       of a two's-complement value rounds toward negative
+///                       infinity under both the signed and the unsigned
+///                       interpretation, so the result does not depend on it.
+/// \param ClosestSmaller [out] Largest multiple of 2^K that is <= \p Target.
+/// \param ClosestBigger  [out] Smallest multiple of 2^K that is >= \p Target;
+///                       wraps modulo 2^BitWidth when no such value is
+///                       representable.
+///
+/// Both outputs are left untouched when no low bit is known to be zero.
+static void computeClosestIntsSatisfyingKnownBits(
+    APInt Target, KnownBits &Known, unsigned BitWidth, bool IsSigned,
+    APInt &ClosestSmaller, APInt &ClosestBigger) {
+  // K = number of low bits that are known to be zero.
+  unsigned NumKnownLowZeros = Known.countMinTrailingZeros();
+  if (NumKnownLowZeros == 0)
+    return;
+
+  // See the parameter documentation: masking is sign-agnostic.
+  (void)IsSigned;
+
+  // Build the mask with APInt so that wide types are handled; a plain
+  // `1 << K` on int overflows once K reaches the width of int.
+  APInt LowMask = APInt::getLowBitsSet(BitWidth, NumKnownLowZeros);
+
+  // Round down by clearing the low K bits.
+  ClosestSmaller = Target & ~LowMask;
+
+  // Round up: identical when Target is already a multiple of 2^K, otherwise
+  // one step of 2^K (== LowMask + 1) above the rounded-down value.
+  ClosestBigger = ClosestSmaller;
+  if (ClosestSmaller != Target)
+    ClosestBigger += LowMask + 1;
+}
+
+/// Emits an assumption that `icmp Pred LHS, RHS` holds, placed at the first
+/// insertion point of \p BB.
+///
+/// \param BB   Block that receives the compare and the assumption.
+/// \param Pred Predicate of the compare to assume.
+/// \param LHS  Left-hand operand of the compare.
+/// \param RHS  Right-hand operand of the compare.
+/// \param Ctx  Context used to construct the IR builder.
+static void insertAssumeICmp(BasicBlock *BB, ICmpInst::Predicate Pred,
+                             Value *LHS, Value *RHS, LLVMContext &Ctx) {
+  IRBuilder<> AssumeBuilder(Ctx);
+  AssumeBuilder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+
+  // The compare is created first so that it precedes the assumption using it.
+  Value *Condition = AssumeBuilder.CreateICmp(Pred, LHS, RHS);
+  AssumeBuilder.CreateAssumption(Condition);
+}
+
+/// Tightens what is known about operand 0 of \p I on each side of the
+/// conditional branches that use \p I, by inserting assumptions into the
+/// branch successors.
+///
+/// Handles `icmp {u,s}{lt,gt} X, C` where C is fully known and the low K bits
+/// of X are known to be zero (so X is a multiple of 2^K). Because X can only
+/// take values on that grid, the bound C can be moved to the nearest grid
+/// point: e.g. with K = 2, `X u< 7` implies `X u< 5` on the true edge and
+/// `X u>= 8` on the false edge.
+///
+/// \param I        The compare whose branch users are inspected.
+/// \param Op0Known Known bits of operand 0 of \p I.
+/// \param Op1Known Known bits of operand 1 of \p I.
+/// \param BitWidth Bit width of the compared values; nothing is done for 0.
+///
+/// NOTE(review): nothing here checks whether an equivalent assumption has
+/// already been inserted, so visiting the same compare repeatedly would emit
+/// duplicates -- confirm how the caller prevents that.
+static void tryToInsertAssumeBasedOnICmpAndKnownBits(ICmpInst &I,
+                                                     KnownBits Op0Known,
+                                                     KnownBits Op1Known,
+                                                     unsigned BitWidth) {
+  if (!BitWidth)
+    return;
+  // Require a constant right-hand side and a left-hand side whose known-zero
+  // bits form a non-empty run starting at bit 0.
+  if (!Op1Known.isConstant() || !Op0Known.Zero.isMask())
+    return;
+
+  // Collect the successors of every conditional branch on I. An assumption
+  // placed at the top of a block must hold on every path into it, so only
+  // successors whose sole predecessor edge is this branch are eligible. This
+  // also rejects a branch whose two edges target the same block.
+  SmallVector<BasicBlock *> TBBs;
+  SmallVector<BasicBlock *> FBBs;
+  for (User *U : I.users()) {
+    auto *BrUse = dyn_cast<BranchInst>(U);
+    if (!BrUse || BrUse->isUnconditional())
+      continue;
+    BasicBlock *TrueBB = BrUse->getSuccessor(0);
+    BasicBlock *FalseBB = BrUse->getSuccessor(1);
+    if (TrueBB->getSinglePredecessor())
+      TBBs.push_back(TrueBB);
+    if (FalseBB->getSinglePredecessor())
+      FBBs.push_back(FalseBB);
+  }
+  if (TBBs.empty() && FBBs.empty())
+    return;
+
+  APInt RHSConst = Op1Known.getConstant();
+  bool IsSigned = I.isSigned();
+
+  // Nearest values below/above the constant that operand 0 can actually take.
+  APInt ClosestSmaller(BitWidth, 0);
+  APInt ClosestBigger(BitWidth, 0);
+  computeClosestIntsSatisfyingKnownBits(RHSConst, Op0Known, BitWidth, IsSigned,
+                                        ClosestSmaller, ClosestBigger);
+
+  // The true edge keeps the original predicate, the false edge its inverse;
+  // only the constant changes.
+  ICmpInst::Predicate AssumePredT = I.getPredicate();
+  ICmpInst::Predicate AssumePredF = ICmpInst::getInversePredicate(AssumePredT);
+  APInt AssumeRHSConstantT(BitWidth, 0);
+  APInt AssumeRHSConstantF(BitWidth, 0);
+  bool CanImproveT = false;
+  bool CanImproveF = false;
+
+  // Strict less-than in the signedness of the compare.
+  auto IsLess = [IsSigned](const APInt &A, const APInt &B) {
+    return IsSigned ? A.slt(B) : A.ult(B);
+  };
+
+  // The `RHSConst - 1` / `RHSConst + 1` below can only wrap when RHSConst is
+  // the minimum (for lt) or maximum (for gt) value, in which case the original
+  // compare is never true and the true successor is not reached through it.
+  switch (AssumePredT) {
+  default:
+    // Other predicates are not handled.
+    break;
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_SLT: {
+    // True edge: X < C and X on the grid  =>  X <= ClosestSmaller, expressed
+    // as X < ClosestSmaller + 1. Only worthwhile if that bound is tighter.
+    if (IsLess(ClosestSmaller, RHSConst - 1)) {
+      CanImproveT = true;
+      AssumeRHSConstantT = ClosestSmaller + 1;
+    }
+    // False edge: X >= C and X on the grid  =>  X >= ClosestBigger.
+    if (IsLess(RHSConst, ClosestBigger)) {
+      CanImproveF = true;
+      AssumeRHSConstantF = ClosestBigger;
+    }
+    break;
+  }
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_SGT: {
+    // True edge: X > C and X on the grid  =>  X >= ClosestBigger, expressed
+    // as X > ClosestBigger - 1.
+    if (IsLess(RHSConst + 1, ClosestBigger)) {
+      CanImproveT = true;
+      AssumeRHSConstantT = ClosestBigger - 1;
+    }
+    // False edge: X <= C and X on the grid  =>  X <= ClosestSmaller.
+    if (IsLess(ClosestSmaller, RHSConst)) {
+      CanImproveF = true;
+      AssumeRHSConstantF = ClosestSmaller;
+    }
+    break;
+  }
+  }
+
+  Value *Op0 = I.getOperand(0);
+  Type *Ty = Op0->getType();
+  LLVMContext &Ctx = I.getContext();
+  if (CanImproveT) {
+    Constant *RHS = ConstantInt::get(Ty, AssumeRHSConstantT);
+    for (BasicBlock *TBB : TBBs)
+      insertAssumeICmp(TBB, AssumePredT, Op0, RHS, Ctx);
+  }
+  if (CanImproveF) {
+    Constant *RHS = ConstantInt::get(Ty, AssumeRHSConstantF);
+    for (BasicBlock *FBB : FBBs)
+      insertAssumeICmp(FBB, AssumePredF, Op0, RHS, Ctx);
+  }
+}
----------------
goldsteinn wrote:
This code is in desperate need of some comments.
https://github.com/llvm/llvm-project/pull/96017
More information about the llvm-commits
mailing list