[llvm-branch-commits] [llvm] d49f649 - [AArch64][GlobalISel] Refactor G_BRCOND selection

Jessica Paquette via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Dec 7 17:28:40 PST 2020


Author: Jessica Paquette
Date: 2020-12-07T17:24:23-08:00
New Revision: d49f6491b6d1439b5a65ff6e965b65a66d943b63

URL: https://github.com/llvm/llvm-project/commit/d49f6491b6d1439b5a65ff6e965b65a66d943b63
DIFF: https://github.com/llvm/llvm-project/commit/d49f6491b6d1439b5a65ff6e965b65a66d943b63.diff

LOG: [AArch64][GlobalISel] Refactor G_BRCOND selection

`selectCompareBranch` was hard to understand.

Also, it was being needlessly pessimistic with the `ProduceNonFlagSettingCondBr`
case. It assumed that everything in `selectCompareBranch` would emit a TB(N)Z
or CB(N)Z. That's not true; the G_FCMP + G_BRCOND case would never emit those
instructions, and the G_ICMP + G_BRCOND case was capable of emitting an integer
compare + Bcc.

- Refactor `selectCompareBranch` into separate functions based on what feeds
the G_BRCOND's condition.

- Move G_BRCOND selection code from `select` to `selectCompareBranch`.

- Remove duplicated constraint code from the code originally in `select`;
  `emitTestBit` already handles that, so no need to constrain twice.

- Factor out the G_FCMP + G_BRCOND case into `selectCompareBranchFedByFCmp`.

- Split the G_ICMP + G_BRCOND case into an optimization function,
`tryOptCompareBranchFedByICmp` and a general selection function,
`selectCompareBranchFedByICmp`.

- Reduce the number of things passed to `tryOptAndIntoCompareBranch`.

- Improve documentation.

- Give some variables more descriptive names.

Other than improving the code generation for functions with
speculative_load_hardening by getting the logic correct, this is NFC.

Differential Revision: https://reviews.llvm.org/D92582

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
    llvm/test/CodeGen/AArch64/GlobalISel/speculative-hardening-brcond.mir

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 982de35aeef1..abb0cb7b4299 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -102,11 +102,19 @@ class AArch64InstructionSelector : public InstructionSelector {
   bool selectVaStartDarwin(MachineInstr &I, MachineFunction &MF,
                            MachineRegisterInfo &MRI) const;
 
-  bool tryOptAndIntoCompareBranch(MachineInstr *LHS,
-                                  int64_t CmpConstant,
-                                  const CmpInst::Predicate &Pred,
+  ///@{
+  /// Helper functions for selectCompareBranch.
+  bool selectCompareBranchFedByFCmp(MachineInstr &I, MachineInstr &FCmp,
+                                    MachineIRBuilder &MIB) const;
+  bool selectCompareBranchFedByICmp(MachineInstr &I, MachineInstr &ICmp,
+                                    MachineIRBuilder &MIB) const;
+  bool tryOptCompareBranchFedByICmp(MachineInstr &I, MachineInstr &ICmp,
+                                    MachineIRBuilder &MIB) const;
+  bool tryOptAndIntoCompareBranch(MachineInstr &AndInst, bool Invert,
                                   MachineBasicBlock *DstMBB,
                                   MachineIRBuilder &MIB) const;
+  ///@}
+
   bool selectCompareBranch(MachineInstr &I, MachineFunction &MF,
                            MachineRegisterInfo &MRI) const;
 
@@ -1369,8 +1377,9 @@ MachineInstr *AArch64InstructionSelector::emitTestBit(
 }
 
 bool AArch64InstructionSelector::tryOptAndIntoCompareBranch(
-    MachineInstr *AndInst, int64_t CmpConstant, const CmpInst::Predicate &Pred,
-    MachineBasicBlock *DstMBB, MachineIRBuilder &MIB) const {
+    MachineInstr &AndInst, bool Invert, MachineBasicBlock *DstMBB,
+    MachineIRBuilder &MIB) const {
+  assert(AndInst.getOpcode() == TargetOpcode::G_AND && "Expected G_AND only?");
   // Given something like this:
   //
   //  %x = ...Something...
@@ -1388,31 +1397,17 @@ bool AArch64InstructionSelector::tryOptAndIntoCompareBranch(
   //
   // TBNZ %x %bb.3
   //
-  if (!AndInst || AndInst->getOpcode() != TargetOpcode::G_AND)
-    return false;
-
-  // Need to be comparing against 0 to fold.
-  if (CmpConstant != 0)
-    return false;
-
-  MachineRegisterInfo &MRI = *MIB.getMRI();
-
-  // Only support EQ and NE. If we have LT, then it *is* possible to fold, but
-  // we don't want to do this. When we have an AND and LT, we need a TST/ANDS,
-  // so folding would be redundant.
-  assert(ICmpInst::isEquality(Pred) && "Expected only eq/ne?");
 
   // Check if the AND has a constant on its RHS which we can use as a mask.
   // If it's a power of 2, then it's the same as checking a specific bit.
   // (e.g, ANDing with 8 == ANDing with 000...100 == testing if bit 3 is set)
-  auto MaybeBit =
-      getConstantVRegValWithLookThrough(AndInst->getOperand(2).getReg(), MRI);
+  auto MaybeBit = getConstantVRegValWithLookThrough(
+      AndInst.getOperand(2).getReg(), *MIB.getMRI());
   if (!MaybeBit || !isPowerOf2_64(MaybeBit->Value))
     return false;
 
   uint64_t Bit = Log2_64(static_cast<uint64_t>(MaybeBit->Value));
-  Register TestReg = AndInst->getOperand(1).getReg();
-  bool Invert = Pred == CmpInst::Predicate::ICMP_NE;
+  Register TestReg = AndInst.getOperand(1).getReg();
 
   // Emit a TB(N)Z.
   emitTestBit(TestReg, Bit, Invert, DstMBB, MIB);
@@ -1440,47 +1435,54 @@ MachineInstr *AArch64InstructionSelector::emitCBZ(Register CompareReg,
   return &*BranchMI;
 }
 
-bool AArch64InstructionSelector::selectCompareBranch(
-    MachineInstr &I, MachineFunction &MF, MachineRegisterInfo &MRI) const {
-
-  const Register CondReg = I.getOperand(0).getReg();
+bool AArch64InstructionSelector::selectCompareBranchFedByFCmp(
+    MachineInstr &I, MachineInstr &FCmp, MachineIRBuilder &MIB) const {
+  assert(FCmp.getOpcode() == TargetOpcode::G_FCMP);
+  assert(I.getOpcode() == TargetOpcode::G_BRCOND);
+  // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't
+  // totally clean.  Some of them require two branches to implement.
+  emitFPCompare(FCmp.getOperand(2).getReg(), FCmp.getOperand(3).getReg(), MIB);
+  AArch64CC::CondCode CC1, CC2;
+  changeFCMPPredToAArch64CC(
+      static_cast<CmpInst::Predicate>(FCmp.getOperand(1).getPredicate()), CC1,
+      CC2);
   MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
-  MachineInstr *CCMI = MRI.getVRegDef(CondReg);
-  if (CCMI->getOpcode() == TargetOpcode::G_TRUNC)
-    CCMI = MRI.getVRegDef(CCMI->getOperand(1).getReg());
+  MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC1).addMBB(DestMBB);
+  if (CC2 != AArch64CC::AL)
+    MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC2).addMBB(DestMBB);
+  I.eraseFromParent();
+  return true;
+}
 
-  unsigned CCMIOpc = CCMI->getOpcode();
-  if (CCMIOpc != TargetOpcode::G_ICMP && CCMIOpc != TargetOpcode::G_FCMP)
+bool AArch64InstructionSelector::tryOptCompareBranchFedByICmp(
+    MachineInstr &I, MachineInstr &ICmp, MachineIRBuilder &MIB) const {
+  assert(ICmp.getOpcode() == TargetOpcode::G_ICMP);
+  assert(I.getOpcode() == TargetOpcode::G_BRCOND);
+  // Attempt to optimize the G_BRCOND + G_ICMP into a TB(N)Z/CB(N)Z.
+  //
+  // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z
+  // instructions will not be produced, as they are conditional branch
+  // instructions that do not set flags.
+  if (!ProduceNonFlagSettingCondBr)
     return false;
 
-  MachineIRBuilder MIB(I);
-  Register LHS = CCMI->getOperand(2).getReg();
-  Register RHS = CCMI->getOperand(3).getReg();
+  MachineRegisterInfo &MRI = *MIB.getMRI();
+  MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
   auto Pred =
-      static_cast<CmpInst::Predicate>(CCMI->getOperand(1).getPredicate());
-
-  if (CCMIOpc == TargetOpcode::G_FCMP) {
-    // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't
-    // totally clean.  Some of them require two branches to implement.
-    emitFPCompare(LHS, RHS, MIB);
-    AArch64CC::CondCode CC1, CC2;
-    changeFCMPPredToAArch64CC(Pred, CC1, CC2);
-    MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC1).addMBB(DestMBB);
-    if (CC2 != AArch64CC::AL)
-      MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC2).addMBB(DestMBB);
-    I.eraseFromParent();
-    return true;
-  }
+      static_cast<CmpInst::Predicate>(ICmp.getOperand(1).getPredicate());
+  Register LHS = ICmp.getOperand(2).getReg();
+  Register RHS = ICmp.getOperand(3).getReg();
 
+  // We're allowed to emit a TB(N)Z/CB(N)Z. Try to do that.
   auto VRegAndVal = getConstantVRegValWithLookThrough(RHS, MRI);
-  MachineInstr *LHSMI = getDefIgnoringCopies(LHS, MRI);
+  MachineInstr *AndInst = getOpcodeDef(TargetOpcode::G_AND, LHS, MRI);
 
   // When we can emit a TB(N)Z, prefer that.
   //
   // Handle non-commutative condition codes first.
   // Note that we don't want to do this when we have a G_AND because it can
   // become a tst. The tst will make the test bit in the TB(N)Z redundant.
-  if (VRegAndVal && LHSMI->getOpcode() != TargetOpcode::G_AND) {
+  if (VRegAndVal && !AndInst) {
     int64_t C = VRegAndVal->Value;
 
     // When we have a greater-than comparison, we can just test if the msb is
@@ -1508,14 +1510,19 @@ bool AArch64InstructionSelector::selectCompareBranch(
     if (!VRegAndVal) {
       std::swap(RHS, LHS);
       VRegAndVal = getConstantVRegValWithLookThrough(RHS, MRI);
-      LHSMI = getDefIgnoringCopies(LHS, MRI);
+      AndInst = getOpcodeDef(TargetOpcode::G_AND, LHS, MRI);
     }
 
     if (VRegAndVal && VRegAndVal->Value == 0) {
       // If there's a G_AND feeding into this branch, try to fold it away by
       // emitting a TB(N)Z instead.
-      if (tryOptAndIntoCompareBranch(LHSMI, VRegAndVal->Value, Pred, DestMBB,
-                                     MIB)) {
+      //
+      // Note: If we have LT, then it *is* possible to fold, but it wouldn't be
+      // beneficial. When we have an AND and LT, we need a TST/ANDS, so folding
+      // would be redundant.
+      if (AndInst &&
+          tryOptAndIntoCompareBranch(
+              *AndInst, /*Invert = */ Pred == CmpInst::ICMP_NE, DestMBB, MIB)) {
         I.eraseFromParent();
         return true;
       }
@@ -1530,15 +1537,66 @@ bool AArch64InstructionSelector::selectCompareBranch(
     }
   }
 
-  // Couldn't optimize. Emit a compare + bcc.
-  emitIntegerCompare(CCMI->getOperand(2), CCMI->getOperand(3),
-                     CCMI->getOperand(1), MIB);
-  const AArch64CC::CondCode CC = changeICMPPredToAArch64CC(Pred);
+  return false;
+}
+
+bool AArch64InstructionSelector::selectCompareBranchFedByICmp(
+    MachineInstr &I, MachineInstr &ICmp, MachineIRBuilder &MIB) const {
+  assert(ICmp.getOpcode() == TargetOpcode::G_ICMP);
+  assert(I.getOpcode() == TargetOpcode::G_BRCOND);
+  if (tryOptCompareBranchFedByICmp(I, ICmp, MIB))
+    return true;
+
+  // Couldn't optimize. Emit a compare + a Bcc.
+  MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
+  auto PredOp = ICmp.getOperand(1);
+  emitIntegerCompare(ICmp.getOperand(2), ICmp.getOperand(3), PredOp, MIB);
+  const AArch64CC::CondCode CC = changeICMPPredToAArch64CC(
+      static_cast<CmpInst::Predicate>(PredOp.getPredicate()));
   MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC).addMBB(DestMBB);
   I.eraseFromParent();
   return true;
 }
 
+bool AArch64InstructionSelector::selectCompareBranch(
+    MachineInstr &I, MachineFunction &MF, MachineRegisterInfo &MRI) const {
+  Register CondReg = I.getOperand(0).getReg();
+  MachineInstr *CCMI = MRI.getVRegDef(CondReg);
+  if (CCMI->getOpcode() == TargetOpcode::G_TRUNC) {
+    CondReg = CCMI->getOperand(1).getReg();
+    CCMI = MRI.getVRegDef(CondReg);
+  }
+
+  // Try to select the G_BRCOND using whatever is feeding the condition if
+  // possible.
+  MachineIRBuilder MIB(I);
+  unsigned CCMIOpc = CCMI->getOpcode();
+  if (CCMIOpc == TargetOpcode::G_FCMP)
+    return selectCompareBranchFedByFCmp(I, *CCMI, MIB);
+  if (CCMIOpc == TargetOpcode::G_ICMP)
+    return selectCompareBranchFedByICmp(I, *CCMI, MIB);
+
+  // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z
+  // instructions will not be produced, as they are conditional branch
+  // instructions that do not set flags.
+  if (ProduceNonFlagSettingCondBr) {
+    emitTestBit(CondReg, /*Bit = */ 0, /*IsNegative = */ true,
+                I.getOperand(1).getMBB(), MIB);
+    I.eraseFromParent();
+    return true;
+  }
+
+  // Can't emit TB(N)Z/CB(N)Z. Emit a tst + bcc instead.
+  auto TstMI =
+      MIB.buildInstr(AArch64::ANDSWri, {LLT::scalar(32)}, {CondReg}).addImm(1);
+  constrainSelectedInstRegOperands(*TstMI, TII, TRI, RBI);
+  auto Bcc = MIB.buildInstr(AArch64::Bcc)
+                 .addImm(AArch64CC::EQ)
+                 .addMBB(I.getOperand(1).getMBB());
+  I.eraseFromParent();
+  return constrainSelectedInstRegOperands(*Bcc, TII, TRI, RBI);
+}
+
 /// Returns the element immediate value of a vector shift operand if found.
 /// This needs to detect a splat-like operation, e.g. a G_BUILD_VECTOR.
 static Optional<int64_t> getVectorShiftImm(Register Reg,
@@ -2108,31 +2166,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
   MachineIRBuilder MIB(I);
 
   switch (Opcode) {
-  case TargetOpcode::G_BRCOND: {
-    Register CondReg = I.getOperand(0).getReg();
-    MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
-
-    // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z
-    // instructions will not be produced, as they are conditional branch
-    // instructions that do not set flags.
-    if (ProduceNonFlagSettingCondBr && selectCompareBranch(I, MF, MRI))
-      return true;
-
-    if (ProduceNonFlagSettingCondBr) {
-      auto TestBit = emitTestBit(CondReg, /*Bit = */ 0, /*IsNegative = */ true,
-                                 DestMBB, MIB);
-      I.eraseFromParent();
-      return constrainSelectedInstRegOperands(*TestBit, TII, TRI, RBI);
-    } else {
-      auto CMP = MIB.buildInstr(AArch64::ANDSWri, {LLT::scalar(32)}, {CondReg})
-                     .addImm(1);
-      constrainSelectedInstRegOperands(*CMP, TII, TRI, RBI);
-      auto Bcc =
-          MIB.buildInstr(AArch64::Bcc).addImm(AArch64CC::EQ).addMBB(DestMBB);
-      I.eraseFromParent();
-      return constrainSelectedInstRegOperands(*Bcc.getInstr(), TII, TRI, RBI);
-    }
-  }
+  case TargetOpcode::G_BRCOND:
+    return selectCompareBranch(I, MF, MRI);
 
   case TargetOpcode::G_BRINDIRECT: {
     I.setDesc(TII.get(AArch64::BR));

diff  --git a/llvm/test/CodeGen/AArch64/GlobalISel/speculative-hardening-brcond.mir b/llvm/test/CodeGen/AArch64/GlobalISel/speculative-hardening-brcond.mir
index d665f962f1d4..7a688d1325f4 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/speculative-hardening-brcond.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/speculative-hardening-brcond.mir
@@ -8,6 +8,7 @@
 --- |
     define void @no_tbnz() speculative_load_hardening { ret void }
     define void @no_cbz() speculative_load_hardening { ret void }
+    define void @fp() speculative_load_hardening { ret void }
 ...
 
 ---
@@ -44,8 +45,6 @@ body:             |
   ; CHECK:   successors: %bb.0(0x40000000), %bb.1(0x40000000)
   ; CHECK:   %reg:gpr32sp = COPY $w0
   ; CHECK:   [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri %reg, 0, 0, implicit-def $nzcv
-  ; CHECK:   %cmp:gpr32 = CSINCWr $wzr, $wzr, 1, implicit $nzcv
-  ; CHECK:   [[ANDSWri:%[0-9]+]]:gpr32 = ANDSWri %cmp, 1, implicit-def $nzcv
   ; CHECK:   Bcc 0, %bb.1, implicit $nzcv
   ; CHECK:   B %bb.0
   ; CHECK: bb.1:
@@ -62,3 +61,29 @@ body:             |
   bb.1:
     RET_ReallyLR
 ...
+---
+name:            fp
+legalized:       true
+regBankSelected: true
+body:             |
+  ; CHECK-LABEL: name: fp
+  ; CHECK: bb.0:
+  ; CHECK:   successors: %bb.0(0x40000000), %bb.1(0x40000000)
+  ; CHECK:   %reg0:fpr32 = COPY $s0
+  ; CHECK:   %reg1:fpr32 = COPY $s1
+  ; CHECK:   FCMPSrr %reg0, %reg1, implicit-def $nzcv
+  ; CHECK:   Bcc 0, %bb.1, implicit $nzcv
+  ; CHECK:   B %bb.0
+  ; CHECK: bb.1:
+  ; CHECK:   RET_ReallyLR
+  bb.0:
+    liveins: $s0, $s1
+    successors: %bb.0, %bb.1
+    %reg0:fpr(s32) = COPY $s0
+    %reg1:fpr(s32) = COPY $s1
+    %cmp:gpr(s32) = G_FCMP floatpred(oeq), %reg0, %reg1
+    %cond:gpr(s1) = G_TRUNC %cmp(s32)
+    G_BRCOND %cond(s1), %bb.1
+    G_BR %bb.0
+  bb.1:
+    RET_ReallyLR


        


More information about the llvm-branch-commits mailing list