[llvm] [AArch64][FEAT_CMPBR] Codegen for Armv9.6-a compare-and-branch (PR #116465)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 15 19:20:07 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: David Tellenbach (dtellenbach)

<details>
<summary>Changes</summary>

This patch adds codegen for all Armv9.6-a compare-and-branch instructions that operate on full w or x registers. The instruction variants operating on half-words (cbh) and bytes (cbb) will be added in a subsequent patch.

Since CB doesn't use the standard 4-bit Arm condition codes but a reduced set of conditions encoded in 3 bits, some conditions are expressed by modifying operands, namely by incrementing or decrementing immediate operands and by swapping register operands. To invert a CB instruction, it is therefore not enough to just modify the condition code, which doesn't play particularly well with how the backend is currently organized. We therefore introduce a number of pseudos that operate on the standard 4-bit condition codes and lower them late during codegen.

---

Patch is 48.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116465.diff


13 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+155) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+24) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+4) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrFormats.td (+19) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+95-1) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.h (+4) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+25) 
- (modified) llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h (+20) 
- (added) llvm/test/CodeGen/AArch64/cmpbr-branch-relaxation.mir (+156) 
- (added) llvm/test/CodeGen/AArch64/cmpbr-early-ifcvt.mir (+116) 
- (added) llvm/test/CodeGen/AArch64/cmpbr-reg-imm-bounds.ll (+66) 
- (added) llvm/test/CodeGen/AArch64/cmpbr-reg-imm.ll (+402) 
- (added) llvm/test/CodeGen/AArch64/cmpbr-reg-reg.ll (+405) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index af26fc62292377..0a403f077f23b6 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -181,6 +181,9 @@ class AArch64AsmPrinter : public AsmPrinter {
   /// pseudo instructions.
   bool lowerPseudoInstExpansion(const MachineInstr *MI, MCInst &Inst);
 
+  // Emit expansion of Compare-and-branch pseudo instructions
+  void emitCBPseudoExpansion(const MachineInstr *MI);
+
   void EmitToStreamer(MCStreamer &S, const MCInst &Inst);
   void EmitToStreamer(const MCInst &Inst) {
     EmitToStreamer(*OutStreamer, Inst);
@@ -2427,6 +2430,151 @@ AArch64AsmPrinter::lowerBlockAddressConstant(const BlockAddress &BA) {
   return BAE;
 }
 
+void AArch64AsmPrinter::emitCBPseudoExpansion(const MachineInstr *MI) {
+  bool IsImm = false;
+  bool Is32Bit = false;
+
+  switch (MI->getOpcode()) {
+  default:
+    llvm_unreachable("This is not a CB pseudo instruction");
+  case AArch64::CBWPrr:
+    IsImm = false;
+    Is32Bit = true;
+    break;
+  case AArch64::CBXPrr:
+    IsImm = false;
+    Is32Bit = false;
+    break;
+  case AArch64::CBWPri:
+    IsImm = true;
+    Is32Bit = true;
+    break;
+  case AArch64::CBXPri:
+    IsImm = true;
+    Is32Bit = false;
+    break;
+  }
+
+  AArch64CC::CondCode CC =
+      static_cast<AArch64CC::CondCode>(MI->getOperand(0).getImm());
+  bool NeedsRegSwap = false;
+  bool NeedsImmDec = false;
+  bool NeedsImmInc = false;
+
+
+  unsigned MCOpC;
+  switch(CC) {
+  default:
+    llvm_unreachable("Invalid CB condition code");
+  case AArch64CC::EQ:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBEQWri : AArch64::CBEQXri)
+                  : (Is32Bit ? AArch64::CBEQWrr : AArch64::CBEQXrr);
+    NeedsRegSwap = false;
+    NeedsImmDec = false;
+    NeedsImmInc = false;
+    break;
+  case AArch64CC::NE:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBNEWri : AArch64::CBNEXri)
+                  : (Is32Bit ? AArch64::CBNEWrr : AArch64::CBNEXrr);
+    NeedsRegSwap = false;
+    NeedsImmDec = false;
+    NeedsImmInc = false;
+    break;
+  case AArch64CC::HS:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBHIWri : AArch64::CBHIXri)
+                  : (Is32Bit ? AArch64::CBHSWrr : AArch64::CBHSXrr);
+    NeedsRegSwap = false;
+    NeedsImmDec = true;
+    NeedsImmInc = false;
+    break;
+  case AArch64CC::LO:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBLOWri : AArch64::CBLOXri)
+                  : (Is32Bit ? AArch64::CBHIWrr : AArch64::CBHIXrr);
+    NeedsRegSwap = true;
+    NeedsImmDec = false;
+    NeedsImmInc = false;
+    break;
+  case AArch64CC::HI:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBHIWri : AArch64::CBHIXri)
+                  : (Is32Bit ? AArch64::CBHIWrr : AArch64::CBHIXrr);
+    NeedsRegSwap = false;
+    NeedsImmDec = false;
+    NeedsImmInc = false;
+    break;
+  case AArch64CC::LS:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBLOWri : AArch64::CBLOXri)
+                  : (Is32Bit ? AArch64::CBHSWrr : AArch64::CBHSXrr);
+    NeedsRegSwap = !IsImm;
+    NeedsImmDec = false;
+    NeedsImmInc = IsImm;
+    break;
+  case AArch64CC::GE:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBGTWri : AArch64::CBGTXri)
+                  : (Is32Bit ? AArch64::CBGEWrr : AArch64::CBGEXrr);
+    NeedsRegSwap = false;
+    NeedsImmDec = IsImm;
+    NeedsImmInc = false;
+    break;
+  case AArch64CC::LT:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBLTWri : AArch64::CBLTXri)
+                  : (Is32Bit ? AArch64::CBGTWrr : AArch64::CBGTXrr);
+    NeedsRegSwap = !IsImm;
+    NeedsImmDec = false;
+    NeedsImmInc = false;
+    break;
+  case AArch64CC::GT:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBGTWri : AArch64::CBGTXri)
+                  : (Is32Bit ? AArch64::CBGTWrr : AArch64::CBGTXrr);
+    NeedsRegSwap = false;
+    NeedsImmDec = false;
+    NeedsImmInc = false;
+    break;
+  case AArch64CC::LE:
+    MCOpC = IsImm ? (Is32Bit ? AArch64::CBLTWri : AArch64::CBLTXri)
+                  : (Is32Bit ? AArch64::CBGEWrr : AArch64::CBGEXrr);
+    NeedsRegSwap = !IsImm;
+    NeedsImmDec = false;
+    NeedsImmInc = IsImm;
+    break;
+  }
+
+  MCInst Inst;
+  Inst.setOpcode(MCOpC);
+
+  MCOperand Lhs, Rhs, Trgt;
+  lowerOperand(MI->getOperand(1), Lhs);
+  lowerOperand(MI->getOperand(2), Rhs);
+  lowerOperand(MI->getOperand(3), Trgt);
+
+  if (NeedsRegSwap) {
+    assert(
+        !IsImm &&
+        "Unexpected register swap for CB instruction with immediate operand");
+    assert(Lhs.isReg() && "Expected register operand for CB");
+    assert(Rhs.isReg() && "Expected register operand for CB");
+    // Swap register operands
+    Inst.addOperand(Rhs);
+    Inst.addOperand(Lhs);
+  } else if (IsImm && NeedsImmDec) {
+    assert(IsImm && "Unexpected immediate decrement for CB instruction with "
+                    "reg-reg operands");
+    Rhs.setImm(Rhs.getImm() - 1);
+    Inst.addOperand(Lhs);
+    Inst.addOperand(Rhs);
+  } else if (NeedsImmInc) {
+    assert(IsImm && "Unexpected immediate increment for CB instruction with "
+                    "reg-reg operands");
+    Rhs.setImm(Rhs.getImm() + 1);
+    Inst.addOperand(Lhs);
+    Inst.addOperand(Rhs);
+  } else {
+    Inst.addOperand(Lhs);
+    Inst.addOperand(Rhs);
+  }
+  Inst.addOperand(Trgt);
+  EmitToStreamer(*OutStreamer, Inst);
+}
+
 // Simple pseudo-instructions have their lowering (with expansion to real
 // instructions) auto-generated.
 #include "AArch64GenMCPseudoLowering.inc"
@@ -2948,6 +3096,13 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     TS->emitARM64WinCFISaveAnyRegQPX(MI->getOperand(0).getImm(),
                                      -MI->getOperand(2).getImm());
     return;
+
+  case AArch64::CBWPri:
+  case AArch64::CBXPri:
+  case AArch64::CBWPrr:
+  case AArch64::CBXPrr:
+    emitCBPseudoExpansion(MI);
+    return;
   }
 
   // Finally, do the automated lowerings for everything else.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d1c3d4eddc880..3e35a85d4fe806 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2954,6 +2954,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::CTTZ_ELTS)
     MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
     MAKE_CASE(AArch64ISD::URSHR_I_PRED)
+    MAKE_CASE(AArch64ISD::CBRR)
+    MAKE_CASE(AArch64ISD::CBRI)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -10396,6 +10398,28 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
                          DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
     }
 
+    // Try to emit Armv9.6 CB instructions. We prefer tb{n}z/cb{n}z due to their
+    // larger branch displacement but do prefer CB over cmp + br.
+    if (Subtarget->hasCMPBR() &&
+        AArch64CC::isValidCBCond(changeIntCCToAArch64CC(CC)) &&
+        ProduceNonFlagSettingCondBr) {
+      AArch64CC::CondCode ACC = changeIntCCToAArch64CC(CC);
+      unsigned Opc = AArch64ISD::CBRR;
+      if (ConstantSDNode *Imm = dyn_cast<ConstantSDNode>(RHS)) {
+        APInt NewImm = Imm->getAPIntValue();
+        if (ACC == AArch64CC::GE || ACC == AArch64CC::HS)
+          NewImm = Imm->getAPIntValue() - 1;
+        else if (ACC == AArch64CC::LE || ACC == AArch64CC::LS)
+          NewImm = Imm->getAPIntValue() + 1;
+
+        if (NewImm.uge(0) && NewImm.ult(64))
+          Opc = AArch64ISD::CBRI;
+      }
+
+      SDValue Cond = DAG.getTargetConstant(ACC, dl, MVT::i32);
+      return DAG.getNode(Opc, dl, MVT::Other, Chain, Cond, LHS, RHS, Dest);
+    }
+
     SDValue CCVal;
     SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
     return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d11da64d3f84eb..7de5f4490e78db 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -520,6 +520,10 @@ enum NodeType : unsigned {
   MOPS_MEMSET_TAGGING,
   MOPS_MEMCOPY,
   MOPS_MEMMOVE,
+
+  // Compare-and-branch
+  CBRR,
+  CBRI,
 };
 
 } // end namespace AArch64ISD
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 15d4e93b915c14..ca2bfae8d7e8a0 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -13065,6 +13065,7 @@ class BaseCmpBranchRegister<RegisterClass regtype, bit sf, bits<3> cc,
       Sched<[WriteBr]> {
   let isBranch = 1;
   let isTerminator = 1;
+  let isCompare = 1;
 
   bits<5> Rm;
   bits<5> Rt;
@@ -13091,6 +13092,7 @@ class BaseCmpBranchImmediate<RegisterClass regtype, bit sf, bits<3> cc,
       Sched<[WriteBr]> {
   let isBranch = 1;
   let isTerminator = 1;
+  let isCompare = 1;
 
   bits<5> Rt;
   bits<6> imm;
@@ -13131,6 +13133,23 @@ multiclass CmpBranchRegisterAlias<string mnemonic, string insn> {
   def : InstAlias<mnemonic # "\t$Rt, $Rm, $target",
                  (!cast<Instruction>(insn # "Xrr") GPR64:$Rm, GPR64:$Rt, am_brcmpcond:$target), 0>;
 }
+
+class CmpBranchRegisterPseudo<RegisterClass regtype>
+  : Pseudo<(outs), (ins ccode:$Cond, regtype:$Rt, regtype:$Rm, am_brcmpcond:$Target), []>,
+    Sched<[WriteBr]> {
+  let isBranch = 1;
+  let isTerminator = 1;
+  let isCompare = 1;
+}
+
+class CmpBranchImmediatePseudo<RegisterClass regtype, ImmLeaf imtype>
+  : Pseudo<(outs), (ins ccode:$Cond, regtype:$Rt, imtype:$Imm, am_brcmpcond:$Target), []>,
+    Sched<[WriteBr]> {
+  let isBranch = true;
+  let isTerminator = true;
+  let isCompare = true;
+}
+
 //----------------------------------------------------------------------------
 // Allow the size specifier tokens to be upper case, not just lower.
 def : TokenAlias<".4B", ".4b">;  // Add dot product
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index a470c03efd5eb4..73cc235982c392 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -62,6 +62,10 @@ using namespace llvm;
 #define GET_INSTRINFO_CTOR_DTOR
 #include "AArch64GenInstrInfo.inc"
 
+static cl::opt<unsigned>
+    CBDisplacementBits("aarch64-cb-offset-bits", cl::Hidden, cl::init(9),
+                       cl::desc("Restrict range of CB instructions (DEBUG)"));
+
 static cl::opt<unsigned> TBZDisplacementBits(
     "aarch64-tbz-offset-bits", cl::Hidden, cl::init(14),
     cl::desc("Restrict range of TB[N]Z instructions (DEBUG)"));
@@ -216,6 +220,17 @@ static void parseCondBranch(MachineInstr *LastInst, MachineBasicBlock *&Target,
     Cond.push_back(MachineOperand::CreateImm(LastInst->getOpcode()));
     Cond.push_back(LastInst->getOperand(0));
     Cond.push_back(LastInst->getOperand(1));
+    break;
+  case AArch64::CBWPri:
+  case AArch64::CBXPri:
+  case AArch64::CBWPrr:
+  case AArch64::CBXPrr:
+    Target = LastInst->getOperand(3).getMBB();
+    Cond.push_back(MachineOperand::CreateImm(-1));
+    Cond.push_back(MachineOperand::CreateImm(LastInst->getOpcode()));
+    Cond.push_back(LastInst->getOperand(0));
+    Cond.push_back(LastInst->getOperand(1));
+    Cond.push_back(LastInst->getOperand(2));
   }
 }
 
@@ -237,6 +252,11 @@ static unsigned getBranchDisplacementBits(unsigned Opc) {
     return CBZDisplacementBits;
   case AArch64::Bcc:
     return BCCDisplacementBits;
+  case AArch64::CBWPri:
+  case AArch64::CBXPri:
+  case AArch64::CBWPrr:
+  case AArch64::CBXPrr:
+    return CBDisplacementBits;
   }
 }
 
@@ -266,6 +286,11 @@ AArch64InstrInfo::getBranchDestBlock(const MachineInstr &MI) const {
   case AArch64::CBNZX:
   case AArch64::Bcc:
     return MI.getOperand(1).getMBB();
+  case AArch64::CBWPri:
+  case AArch64::CBXPri:
+  case AArch64::CBWPrr:
+  case AArch64::CBXPrr:
+    return MI.getOperand(3).getMBB();
   }
 }
 
@@ -543,6 +568,17 @@ bool AArch64InstrInfo::reverseBranchCondition(
     case AArch64::TBNZX:
       Cond[1].setImm(AArch64::TBZX);
       break;
+
+    // Cond is { -1, Opcode, CC, Op0, Op1 }
+    case AArch64::CBWPri:
+    case AArch64::CBXPri:
+    case AArch64::CBWPrr:
+    case AArch64::CBXPrr: {
+      // Pseudos using standard 4bit Arm condition codes
+      AArch64CC::CondCode CC =
+          static_cast<AArch64CC::CondCode>(Cond[2].getImm());
+      Cond[2].setImm(AArch64CC::getInvertedCondCode(CC));
+    } break;
     }
   }
 
@@ -593,10 +629,19 @@ void AArch64InstrInfo::instantiateCondBranch(
   } else {
     // Folded compare-and-branch
     // Note that we use addOperand instead of addReg to keep the flags.
+
+    // cbz, cbnz
     const MachineInstrBuilder MIB =
         BuildMI(&MBB, DL, get(Cond[1].getImm())).add(Cond[2]);
+
+    // tbz/tbnz
     if (Cond.size() > 3)
-      MIB.addImm(Cond[3].getImm());
+      MIB.add(Cond[3]);
+
+    // cb
+    if (Cond.size() > 4)
+      MIB.add(Cond[4]);
+
     MIB.addMBB(TBB);
   }
 }
@@ -842,6 +887,51 @@ void AArch64InstrInfo::insertSelect(MachineBasicBlock &MBB,
               AArch64_AM::encodeLogicalImmediate(1ull << Cond[3].getImm(), 64));
     break;
   }
+  case 5: { // cb
+    // We must insert a cmp, that is a subs
+    //            0       1   2    3    4
+    // Cond is { -1, Opcode, CC, Op0, Op1 }
+    unsigned SUBSOpC, SUBSDestReg;
+    bool IsImm = false;
+    switch (Cond[1].getImm()) {
+    default:
+      llvm_unreachable("Unknown branch opcode in Cond");
+    case AArch64::CBWPri:
+      SUBSOpC = AArch64::SUBSWri;
+      SUBSDestReg = AArch64::WZR;
+      IsImm = true;
+      CC = static_cast<AArch64CC::CondCode>(Cond[2].getImm());
+      break;
+    case AArch64::CBXPri:
+      SUBSOpC = AArch64::SUBSXri;
+      SUBSDestReg = AArch64::XZR;
+      IsImm = true;
+      CC = static_cast<AArch64CC::CondCode>(Cond[2].getImm());
+      break;
+    case AArch64::CBWPrr:
+      SUBSOpC = AArch64::SUBSWrr;
+      SUBSDestReg = AArch64::WZR;
+      IsImm = false;
+      CC = static_cast<AArch64CC::CondCode>(Cond[2].getImm());
+      break;
+    case AArch64::CBXPrr:
+      SUBSOpC = AArch64::SUBSXrr;
+      SUBSDestReg = AArch64::XZR;
+      IsImm = false;
+      CC = static_cast<AArch64CC::CondCode>(Cond[2].getImm());
+      break;
+    }
+
+    if (IsImm)
+      BuildMI(MBB, I, DL, get(SUBSOpC), SUBSDestReg)
+          .addReg(Cond[3].getReg())
+          .addImm(Cond[4].getImm())
+          .addImm(0);
+    else
+      BuildMI(MBB, I, DL, get(SUBSOpC), SUBSDestReg)
+          .addReg(Cond[3].getReg())
+          .addReg(Cond[4].getReg());
+  }
   }
 
   unsigned Opc = 0;
@@ -8393,6 +8483,10 @@ bool AArch64InstrInfo::optimizeCondBranch(MachineInstr &MI) const {
   default:
     llvm_unreachable("Unknown branch instruction?");
   case AArch64::Bcc:
+  case AArch64::CBWPri:
+  case AArch64::CBXPri:
+  case AArch64::CBWPrr:
+  case AArch64::CBXPrr:
     return false;
   case AArch64::CBZW:
   case AArch64::CBZX:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index e37f70f7d985de..151e397edd6195 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -693,6 +693,10 @@ static inline bool isCondBranchOpcode(int Opc) {
   case AArch64::TBZX:
   case AArch64::TBNZW:
   case AArch64::TBNZX:
+  case AArch64::CBWPri:
+  case AArch64::CBXPri:
+  case AArch64::CBWPrr:
+  case AArch64::CBXPrr:
     return true;
   default:
     return false;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index e4ad27d4bcfc00..67efe50bc1f5f3 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -508,6 +508,9 @@ def SDT_AArch64TBL : SDTypeProfile<1, 2, [
   SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisInt<2>
 ]>;
 
+def SDT_AArch64cbrr : SDTypeProfile<0, 4, [SDTCisVT<0, i32>, SDTCisInt<1>, SDTCisSameAs<1, 2>, SDTCisVT<3, OtherVT>]>;
+def SDT_AArch64cbri : SDTypeProfile<0, 4, [SDTCisVT<0, i32>, SDTCisInt<1>, SDTCisInt<2>, SDTCisVT<3, OtherVT>]>;
+
 // non-extending masked load fragment.
 def nonext_masked_load :
   PatFrag<(ops node:$ptr, node:$pred, node:$def),
@@ -684,6 +687,8 @@ def topbitsallzero64: PatLeaf<(i64 GPR64:$src), [{
   }]>;
 
 // Node definitions.
+def AArch64CBrr : SDNode<"AArch64ISD::CBRR", SDT_AArch64cbrr, [SDNPHasChain]>;
+def AArch64CBri : SDNode<"AArch64ISD::CBRI", SDT_AArch64cbri, [SDNPHasChain]>;
 def AArch64adrp          : SDNode<"AArch64ISD::ADRP", SDTIntUnaryOp, []>;
 def AArch64adr           : SDNode<"AArch64ISD::ADR", SDTIntUnaryOp, []>;
 def AArch64addlow        : SDNode<"AArch64ISD::ADDlow", SDTIntBinOp, []>;
@@ -10481,6 +10486,10 @@ defm : PromoteBinaryv8f16Tov4f32<any_fdiv, FDIVv4f32>;
 defm : PromoteBinaryv8f16Tov4f32<any_fmul, FMULv4f32>;
 defm : PromoteBinaryv8f16Tov4f32<any_fsub, FSUBv4f32>;
 
+//===----------------------------------------------------------------------===//
+// Compare and Branch (FEAT_CMPBR)
+//===----------------------------------------------------------------------===//
+
 let Predicates = [HasCMPBR] in {
  defm CBGT : CmpBranchRegister<0b000, "cbgt">;
  defm CBGE : CmpBranchRegister<0b001, "cbge">;
@@ -10529,6 +10538,22 @@ let Predicates = [HasCMPBR] in {
  defm : CmpBranchWRegisterAlias<"cbhlo", "CBHHI">;
  defm : CmpBranchWRegisterAlias<"cbhls", "CBHHS">;
  defm : CmpBranchWRegisterAlias<"cbhlt", "CBHGT">;
+
+  // Pseudos for codegen
+  def CBWPrr : CmpBranchRegisterPseudo<GPR32>;
+  def CBXPrr : CmpBranchRegisterPseudo<GPR64>;
+  def CBWPri : CmpBranchImmediatePseudo<GPR32, uimm6_32b>;
+  def CBXPri : CmpBranchImmediatePseudo<GPR64, uimm6_64b>;
+
+def : Pat<(AArch64CBrr i32:$Cond, GPR32:$Rn, GPR32:$Rt, bb:$Target),
+          (CBWPrr ccode:$Cond, GPR32:$Rn, GPR32:$Rt, am_brcmpcond:$Target)>;
+def : Pat<(AArch64CBrr i32:$Cond, GPR64:$Rn, GPR64:$Rt, bb:$Target),
+          (CBXPrr ccode:$Cond, GPR64:$Rn, GPR64:$Rt, am_brcmpcond:$Target)>;
+def : Pat<(AArch64CBri i32:$Cond, GPR32:$Rn, i32:$Imm, bb:$Target),
+          (CBWPri ccode:$Cond, GPR32:$Rn, uimm6_32b:$Imm, am_brcmpcond:$Target)>;
+def : Pat<(AArch64CBri i32:$Cond, GPR64:$Rn, i64:$Imm, bb:$Target),
+          (CBXPri ccode:$Cond, GPR64:$Rn, uimm6_64b:$Imm, am_brcmpcond:$Target)>;
+
 } // HasCMPBR
 
 
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
index 8f34cf054fe286..417c152eebf24a 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
@@ -332,6 +332,26 @@ inline static unsigned getNZCVToSatisfyCondCode(CondCode Code) {
   }
 }
 
+/// True, if a given condition code can be used in a fused compare-and-branch
+/// instructions, false otherwise.
+inline static bool isValidCBCond(AArch64CC::CondCode Code) {
+  switch (Code) {
+  default:
+    return false;
+  case AArch64CC::EQ:
+  case AArch64CC::NE:
+  case AArch64CC::HS:
+  case AArch64CC::LO:
+  case AArch64CC::HI:
+  case AArch64CC::LS:
+  case AArch64CC::GE:
+  case AArch64CC::LT:
+  case AArch64CC::GT:
+  case AArch64CC::LE:
+    return true;
+  }
+}
+
 } // end namespace AArch64CC
 
 struct SysAlias {
diff --git a/llvm/test/CodeGen/AArch64/cmpbr-branch-relaxation.mir b/llvm/test/CodeGen/AArch64/cmpbr-branch-relaxation.mir
new file mode 100644
index 00000000000000..5fccb452e9642b
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/cmpbr-branch-relaxation.mir
@@ -0,0 +1,156 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 5
+# RUN: llc -mtriple arm64-apple-ios -mattr +cmpbr -o - -aarch64-cb-offset-bits=3 \
+# RUN:    -run-pass=branch-relaxation -verify-machineinstrs -simplify-mir %s | \
+# RUN:    FileCheck -check-prefix=RELAX %s
+# RUN: llc -mtriple arm64-apple-ios -mattr +cmpbr -o - -aarch64-cb-offset-bits=9 \
+# RUN:    -run-pass=branch-relaxation -verify-machineinstrs -simplify-mir %s | \
+# RUN:    FileCheck -check-prefix=NO-RELAX %s
+---
+name:            relax_cb
+registers:
+  - { id: 0, class: gpr32 }
+  - { id: 1, class: gpr32 }
+liveins:
+  - { reg: '$w0', virtual-reg: '%0' }
+  - { reg: '$w1', virtual-reg: '%1' }
+body:             |
+  ; RELAX-LABEL: name: relax_cb
+  ; RELAX: bb.0:
+  ; RELAX-NEXT:   [[COPY:%[0-9]+]]:gpr32 = COPY $w0
+  ; RELAX-NEXT:   [[COPY1:%[0-9]+]]:gpr32 = COPY $w1
+  ; RELAX-NEXT:   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/116465


More information about the llvm-commits mailing list