[llvm] [GlobalIsel] Combine logic of icmps (PR #77855)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 11 20:16:39 PST 2024
================
@@ -6643,3 +6644,178 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
return false;
}
+
+/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
+/// or   (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
+/// into a single comparison using range-based reasoning.
+///
+/// Both compares must have a constant right-hand side and must test the same
+/// register R, optionally through a G_ADD of a constant offset. Each compare
+/// is turned into a ConstantRange over R. If the two ranges combine into one
+/// range (or are equal-sized and differ in exactly one bit, in which case a
+/// mask is applied first), the logic op is rewritten as a single G_ICMP,
+/// possibly preceded by a G_AND and/or a G_ADD on R.
+///
+/// See InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
+///
+/// \param Logic     The G_AND or G_OR combining the two compares (G_XOR is
+///                  not supported).
+/// \param MatchInfo Receives the function that builds the replacement.
+/// \returns true if the fold applies and MatchInfo was set.
+bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic,
+                                                    BuildFnTy &MatchInfo) {
+  assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
+  bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
+  Register DstReg = Logic->getReg(0);
+  Register LHS = Logic->getLHSReg();
+  Register RHS = Logic->getRHSReg();
+
+  // Both operands of the logic op must be defined by a G_ICMP.
+  GICmp *Cmp1 = getOpcodeDef<GICmp>(LHS, MRI);
+  if (!Cmp1)
+    return false;
+  GICmp *Cmp2 = getOpcodeDef<GICmp>(RHS, MRI);
+  if (!Cmp2)
+    return false;
+
+  // Both compares need a constant right-hand side.
+  std::optional<ValueAndVReg> MaybeC1 =
+      getIConstantVRegValWithLookThrough(Cmp1->getRHSReg(), MRI);
+  if (!MaybeC1)
+    return false;
+  std::optional<ValueAndVReg> MaybeC2 =
+      getIConstantVRegValWithLookThrough(Cmp2->getRHSReg(), MRI);
+  if (!MaybeC2)
+    return false;
+  APInt C1 = MaybeC1->Value;
+  APInt C2 = MaybeC2->Value;
+
+  Register R1 = Cmp1->getLHSReg();
+  Register R2 = Cmp2->getLHSReg();
+  CmpInst::Predicate Pred1 = Cmp1->getCond();
+  CmpInst::Predicate Pred2 = Cmp2->getCond();
+  LLT CmpTy = MRI.getType(Cmp1->getReg(0));
+  LLT CmpOperandTy = MRI.getType(R1);
+
+  // The replacement may contain G_AND, G_ADD and G_CONSTANT of type
+  // CmpOperandTy. G_AND/G_ADD are integer operations, so pointer (and vector)
+  // operands are rejected; the operations must also be legal to build.
+  if (!CmpOperandTy.isScalar() ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_AND, CmpOperandTy}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, CmpOperandTy}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, CmpOperandTy}))
+    return false;
+
+  // Look through a G_ADD of a constant offset on R1, R2, or both. This lets
+  // the (R + C') pred C'' range-check idiom be read as a proper range on R.
+  std::optional<APInt> Offset1;
+  std::optional<APInt> Offset2;
+  if (R1 != R2) {
+    if (GAdd *Add = getOpcodeDef<GAdd>(R1, MRI)) {
+      if (std::optional<ValueAndVReg> MaybeOffset1 =
+              getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI)) {
+        R1 = Add->getLHSReg();
+        Offset1 = MaybeOffset1->Value;
+      }
+    }
+    if (GAdd *Add = getOpcodeDef<GAdd>(R2, MRI)) {
+      if (std::optional<ValueAndVReg> MaybeOffset2 =
+              getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI)) {
+        R2 = Add->getLHSReg();
+        Offset2 = MaybeOffset2->Value;
+      }
+    }
+  }
+
+  // Both compares must constrain the same register.
+  if (R1 != R2)
+    return false;
+
+  // For G_OR, CRn is the set of R values for which compare n is true. For
+  // G_AND the predicates are inverted, so CRn is the set for which compare n
+  // is false (De Morgan). In both cases the interesting set is CR1 u CR2; for
+  // G_AND it is inverted back below. An offset (R + Off) shifts the range by
+  // -Off.
+  ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
+      IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, C1);
+  if (Offset1)
+    CR1 = CR1.subtract(*Offset1);
+
+  ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
+      IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, C2);
+  if (Offset2)
+    CR2 = CR2.subtract(*Offset2);
+
+  bool CreateMask = false;
+  APInt LowerDiff;
+  std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
+  if (!CR) {
+    // The union is not a single range. We may still fold when each compare
+    // has a single non-debug use and the ranges do not wrap.
+    if (!MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
+        !MRI.hasOneNonDBGUse(Cmp2->getReg(0)) || CR1.isWrappedSet() ||
+        CR2.isWrappedSet())
+      return false;
+
+    // Check whether we have equal-size ranges that only differ by one bit.
+    // In that case we can apply a mask to map one range onto the other.
+    LowerDiff = CR1.getLower() ^ CR2.getLower();
+    APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
+    APInt CR1Size = CR1.getUpper() - CR1.getLower();
+    if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
+        CR1Size != CR2.getUpper() - CR2.getLower())
+      return false;
+
+    // Clearing the differing bit maps the higher range onto the lower one, so
+    // after masking the lower range describes both.
+    CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2;
+    CreateMask = true;
+  }
+
+  if (IsAnd)
+    CR = CR->inverse();
+
+  // Express membership in CR as (R + Offset) NewPred NewC.
+  CmpInst::Predicate NewPred;
+  APInt NewC, Offset;
+  CR->getEquivalentICmp(NewPred, NewC, Offset);
+
+  // Build [R & ~LowerDiff] [+ Offset] NewPred NewC. The G_AND, G_ADD and
+  // constants use CmpOperandTy; the new G_ICMP produces CmpTy, the result
+  // type of the original compares. The G_ADD deliberately gets no flags: the
+  // logic op's flags say nothing about this new add, and the range-check form
+  // may rely on wrap-around.
+  MatchInfo = [=](MachineIRBuilder &B) {
+    Register NewV = R1;
+    if (CreateMask) {
+      auto Mask = B.buildConstant(CmpOperandTy, ~LowerDiff);
+      NewV = B.buildAnd(CmpOperandTy, NewV, Mask).getReg(0);
+    }
+    if (Offset != 0) {
+      auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
+      NewV = B.buildAdd(CmpOperandTy, NewV, OffsetC).getReg(0);
+    }
+    auto NewCon = B.buildConstant(CmpOperandTy, NewC);
+    auto ICmp = B.buildICmp(NewPred, CmpTy, NewV, NewCon);
+    // DstReg is the G_AND/G_OR result; when its type equals CmpTy this
+    // degenerates to a plain COPY.
+    B.buildZExtOrTrunc(DstReg, ICmp);
+  };
+  return true;
+}
+
+/// Try to combine a G_AND. The only fold attempted here merges two integer
+/// compares via range reasoning (tryFoldAndOrOrICmpsUsingRanges).
+/// \returns true if MatchInfo was set to a replacement builder.
+bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) {
+  return tryFoldAndOrOrICmpsUsingRanges(cast<GAnd>(&MI), MatchInfo);
+}
+
+bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) {
+ GOr *Or = cast<GOr>(&MI);
+
+ if (tryFoldAndOrOrICmpsUsingRanges(Or, MatchInfo))
----------------
arsenm wrote:
Ditto
https://github.com/llvm/llvm-project/pull/77855
More information about the llvm-commits
mailing list