[llvm] [GlobalIsel] Combine logic of icmps (PR #77855)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 20:16:39 PST 2024


================
@@ -6643,3 +6644,178 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
 
   return false;
 }
+
+/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
+/// or   (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
+/// into a single comparison using range-based reasoning.
+/// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
+///
+/// \param Logic the G_AND or G_OR whose two operands are inspected; G_XOR is
+///        rejected by assertion.
+/// \param MatchInfo on success, set to a function that builds the replacement
+///        instructions and defines the destination register of \p Logic.
+/// \returns true if the fold applies and \p MatchInfo has been set.
+bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic,
+                                                    BuildFnTy &MatchInfo) {
+  assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
+  bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
+  Register DstReg = Logic->getReg(0);
+  Register LHS = Logic->getLHSReg();
+  Register RHS = Logic->getRHSReg();
+  unsigned Flags = Logic->getFlags();
+
+  // We need a G_ICMP on the LHS register.
+  GICmp *Cmp1 = getOpcodeDef<GICmp>(LHS, MRI);
+  if (!Cmp1)
+    return false;
+
+  // We need a G_ICMP on the RHS register.
+  GICmp *Cmp2 = getOpcodeDef<GICmp>(RHS, MRI);
+  if (!Cmp2)
+    return false;
+
+  // Both comparisons must compare against a constant.
+  APInt C1;
+  APInt C2;
+  std::optional<ValueAndVReg> MaybeC1 =
+      getIConstantVRegValWithLookThrough(Cmp1->getRHSReg(), MRI);
+  if (!MaybeC1)
+    return false;
+  C1 = MaybeC1->Value;
+
+  std::optional<ValueAndVReg> MaybeC2 =
+      getIConstantVRegValWithLookThrough(Cmp2->getRHSReg(), MRI);
+  if (!MaybeC2)
+    return false;
+  C2 = MaybeC2->Value;
+
+  Register R1 = Cmp1->getLHSReg();
+  Register R2 = Cmp2->getLHSReg();
+  CmpInst::Predicate Pred1 = Cmp1->getCond();
+  CmpInst::Predicate Pred2 = Cmp2->getCond();
+  LLT CmpTy = MRI.getType(Cmp1->getReg(0));
+  LLT CmpOperandTy = MRI.getType(R1);
+
+  // The replacement may apply G_AND and G_ADD directly to the compared value.
+  // Those opcodes are not defined for pointer types, so give up on pointer
+  // comparisons instead of building invalid generic MIR.
+  if (CmpOperandTy.isPointer())
+    return false;
+
+  // We build ands, adds, and constants of type CmpOperandTy.
+  // They must be legal to build.
+  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_AND, CmpOperandTy}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, CmpOperandTy}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, CmpOperandTy}))
+    return false;
+
+  // Look through add of a constant offset on R1, R2, or both operands. This
+  // allows us to interpret the R + C' < C'' range idiom into a proper range.
+  std::optional<APInt> Offset1;
+  std::optional<APInt> Offset2;
+  if (R1 != R2) {
+    if (GAdd *Add = getOpcodeDef<GAdd>(R1, MRI)) {
+      std::optional<ValueAndVReg> MaybeOffset1 =
+          getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
+      if (MaybeOffset1) {
+        R1 = Add->getLHSReg();
+        Offset1 = MaybeOffset1->Value;
+      }
+    }
+    if (GAdd *Add = getOpcodeDef<GAdd>(R2, MRI)) {
+      std::optional<ValueAndVReg> MaybeOffset2 =
+          getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
+      if (MaybeOffset2) {
+        R2 = Add->getLHSReg();
+        Offset2 = MaybeOffset2->Value;
+      }
+    }
+  }
+
+  // Both comparisons must test the same underlying value.
+  if (R1 != R2)
+    return false;
+
+  // We calculate the icmp ranges including maybe offsets. For an and, the
+  // inverted predicates are used here and the combined range is inverted
+  // again below, so that only a union of ranges has to be computed.
+  ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
+      IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, C1);
+  if (Offset1)
+    CR1 = CR1.subtract(*Offset1);
+
+  ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
+      IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, C2);
+  if (Offset2)
+    CR2 = CR2.subtract(*Offset2);
+
+  bool CreateMask = false;
+  APInt LowerDiff;
+  std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
+  if (!CR) {
+    // We want to fold the icmps.
+    if (!MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
+        !MRI.hasOneNonDBGUse(Cmp2->getReg(0)) || CR1.isWrappedSet() ||
+        CR2.isWrappedSet())
+      return false;
+
+    // Check whether we have equal-size ranges that only differ by one bit.
+    // In that case we can apply a mask to map one range onto the other.
+    LowerDiff = CR1.getLower() ^ CR2.getLower();
+    APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
+    APInt CR1Size = CR1.getUpper() - CR1.getLower();
+    if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
+        CR1Size != CR2.getUpper() - CR2.getLower())
+      return false;
+
+    // Keep the range with the smaller lower bound; the mask clears the
+    // differing bit of the compared value before the final comparison.
+    CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2;
+    CreateMask = true;
+  }
+
+  if (IsAnd)
+    CR = CR->inverse();
+
+  CmpInst::Predicate NewPred;
+  APInt NewC, Offset;
+  CR->getEquivalentICmp(NewPred, NewC, Offset);
+
+  // We take the result type of one of the original icmps, CmpTy, for
+  // the icmp to be built. The operand type, CmpOperandTy, is used for
+  // the other instructions and constants to be built. The types of
+  // the parameters and output are the same for add and and. CmpTy
+  // and the type of DstReg might differ. That is why we zext or trunc
+  // the icmp into the destination register.
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    // Value fed into the final icmp; refined by the optional mask and add.
+    Register Val = R1;
+    if (CreateMask) {
+      auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
+      auto And = B.buildAnd(CmpOperandTy, Val, TildeLowerDiff); // the mask.
+      Val = And.getReg(0);
+    }
+    if (Offset != 0) {
+      auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
+      auto Add = B.buildAdd(CmpOperandTy, Val, OffsetC, Flags);
+      Val = Add.getReg(0);
+    }
+    auto NewCon = B.buildConstant(CmpOperandTy, NewC);
+    auto ICmp = B.buildICmp(NewPred, CmpTy, Val, NewCon);
+    B.buildZExtOrTrunc(DstReg, ICmp);
+  };
+  return true;
+}
+
+/// Match combines rooted at the G_AND \p MI. On success \p MatchInfo holds
+/// the function that builds the replacement.
+bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) {
+  // The range-based fold of two icmps is the only pattern tried here, so its
+  // result is the result of the match.
+  return tryFoldAndOrOrICmpsUsingRanges(cast<GAnd>(&MI), MatchInfo);
+}
+
+bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) {
+  GOr *Or = cast<GOr>(&MI);
+
+  if (tryFoldAndOrOrICmpsUsingRanges(Or, MatchInfo))
----------------
arsenm wrote:

Ditto 

https://github.com/llvm/llvm-project/pull/77855


More information about the llvm-commits mailing list