[llvm] 2ec1610 - [AArch64] Peephole rule to remove redundant cmp after cset.

Pavel Iliin via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 19 11:59:08 PDT 2021


Author: Pavel Iliin
Date: 2021-04-19T19:58:38+01:00
New Revision: 2ec16103c68528669080040629961217662353cd

URL: https://github.com/llvm/llvm-project/commit/2ec16103c68528669080040629961217662353cd
DIFF: https://github.com/llvm/llvm-project/commit/2ec16103c68528669080040629961217662353cd.diff

LOG: [AArch64] Peephole rule to remove redundant cmp after cset.

Comparisons to zero or one after cset instructions can be safely
removed in examples like:

cset w9, eq          cset w9, eq
cmp  w9, #1   --->   <removed>
b.ne    .L1          b.ne    .L1

cset w9, eq          cset w9, eq
cmp  w9, #0   --->   <removed>
b.ne    .L1          b.eq    .L1

Added a peephole optimization that detects suitable cases and removes
those comparisons.

Differential Revision: https://reviews.llvm.org/D98564

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.h
    llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
    llvm/test/CodeGen/AArch64/f16-instructions.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 64adc973beeb..28edf104fc0b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1463,14 +1463,16 @@ bool AArch64InstrInfo::optimizeCompareInstr(
   // FIXME:CmpValue has already been converted to 0 or 1 in analyzeCompare
   // function.
   assert((CmpValue == 0 || CmpValue == 1) && "CmpValue must be 0 or 1!");
-  if (CmpValue != 0 || SrcReg2 != 0)
+  if (SrcReg2 != 0)
     return false;
 
   // CmpInstr is a Compare instruction if destination register is not used.
   if (!MRI->use_nodbg_empty(CmpInstr.getOperand(0).getReg()))
     return false;
 
-  return substituteCmpToZero(CmpInstr, SrcReg, MRI);
+  if (!CmpValue && substituteCmpToZero(CmpInstr, SrcReg, *MRI))
+    return true;
+  return removeCmpToZeroOrOne(CmpInstr, SrcReg, CmpValue, *MRI);
 }
 
 /// Get opcode of S version of Instr.
@@ -1524,13 +1526,44 @@ static unsigned sForm(MachineInstr &Instr) {
 }
 
 /// Check if AArch64::NZCV should be alive in successors of MBB.
-static bool areCFlagsAliveInSuccessors(MachineBasicBlock *MBB) {
+static bool areCFlagsAliveInSuccessors(const MachineBasicBlock *MBB) {
   for (auto *BB : MBB->successors())
     if (BB->isLiveIn(AArch64::NZCV))
       return true;
   return false;
 }
 
+/// \returns The index of the condition code operand of \p Instr if it is one
+/// of the branch or select opcodes handled below, and -1 otherwise.
+static int
+findCondCodeUseOperandIdxForBranchOrSelect(const MachineInstr &Instr) {
+  switch (Instr.getOpcode()) {
+  default:
+    return -1;
+
+  case AArch64::Bcc: {
+    int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
+    assert(Idx >= 2);
+    return Idx - 2; // cond code sits two operands before the NZCV use
+  }
+
+  case AArch64::CSINVWr:
+  case AArch64::CSINVXr:
+  case AArch64::CSINCWr:
+  case AArch64::CSINCXr:
+  case AArch64::CSELWr:
+  case AArch64::CSELXr:
+  case AArch64::CSNEGWr:
+  case AArch64::CSNEGXr:
+  case AArch64::FCSELSrrr:
+  case AArch64::FCSELDrrr: {
+    int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
+    assert(Idx >= 1);
+    return Idx - 1; // cond code sits right before the NZCV use
+  }
+  }
+}
+
 namespace {
 
 struct UsedNZCV {
@@ -1556,31 +1589,10 @@ struct UsedNZCV {
 /// Returns AArch64CC::Invalid if either the instruction does not use condition
 /// codes or we don't optimize CmpInstr in the presence of such instructions.
 static AArch64CC::CondCode findCondCodeUsedByInstr(const MachineInstr &Instr) {
-  switch (Instr.getOpcode()) {
-  default:
-    return AArch64CC::Invalid;
-
-  case AArch64::Bcc: {
-    int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
-    assert(Idx >= 2);
-    return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 2).getImm());
-  }
-
-  case AArch64::CSINVWr:
-  case AArch64::CSINVXr:
-  case AArch64::CSINCWr:
-  case AArch64::CSINCXr:
-  case AArch64::CSELWr:
-  case AArch64::CSELXr:
-  case AArch64::CSNEGWr:
-  case AArch64::CSNEGXr:
-  case AArch64::FCSELSrrr:
-  case AArch64::FCSELDrrr: {
-    int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
-    assert(Idx >= 1);
-    return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 1).getImm());
-  }
-  }
+  int CCIdx = findCondCodeUseOperandIdxForBranchOrSelect(Instr);
+  return CCIdx >= 0 ? static_cast<AArch64CC::CondCode>(
+                          Instr.getOperand(CCIdx).getImm())
+                    : AArch64CC::Invalid;
 }
 
 static UsedNZCV getUsedNZCV(AArch64CC::CondCode CC) {
@@ -1627,6 +1639,41 @@ static UsedNZCV getUsedNZCV(AArch64CC::CondCode CC) {
   return UsedFlags;
 }
 
+/// \returns The condition flags read after \p CmpInstr in its basic block if
+/// none of them is C or V, \p MI is in the same block, and NZCV is not live
+/// into any successor of that block. \returns None otherwise.
+///
+/// Appends the instructions reading those flags to \p CCUseInstrs if provided.
+static Optional<UsedNZCV>
+examineCFlagsUse(MachineInstr &MI, MachineInstr &CmpInstr,
+                 const TargetRegisterInfo &TRI,
+                 SmallVectorImpl<MachineInstr *> *CCUseInstrs = nullptr) {
+  MachineBasicBlock *CmpParent = CmpInstr.getParent();
+  if (MI.getParent() != CmpParent)
+    return None;
+
+  if (areCFlagsAliveInSuccessors(CmpParent))
+    return None;
+
+  UsedNZCV NZCVUsedAfterCmp;
+  for (MachineInstr &Instr : instructionsWithoutDebug(
+           std::next(CmpInstr.getIterator()), CmpParent->instr_end())) {
+    if (Instr.readsRegister(AArch64::NZCV, &TRI)) {
+      AArch64CC::CondCode CC = findCondCodeUsedByInstr(Instr);
+      if (CC == AArch64CC::Invalid) // Unsupported conditional instruction
+        return None;
+      NZCVUsedAfterCmp |= getUsedNZCV(CC);
+      if (CCUseInstrs)
+        CCUseInstrs->push_back(&Instr);
+    }
+    if (Instr.modifiesRegister(AArch64::NZCV, &TRI))
+      break; // NZCV is redefined; later reads do not see CmpInstr's flags
+  }
+  if (NZCVUsedAfterCmp.C || NZCVUsedAfterCmp.V)
+    return None;
+  return NZCVUsedAfterCmp;
+}
+
 static bool isADDSRegImm(unsigned Opcode) {
   return Opcode == AArch64::ADDSWri || Opcode == AArch64::ADDSXri;
 }
@@ -1646,44 +1693,21 @@ static bool isSUBSRegImm(unsigned Opcode) {
 ///        or if MI opcode is not the S form there must be neither defs of flags
 ///        nor uses of flags between MI and CmpInstr.
 /// - and  C/V flags are not used after CmpInstr
-static bool canInstrSubstituteCmpInstr(MachineInstr *MI, MachineInstr *CmpInstr,
-                                       const TargetRegisterInfo *TRI) {
-  assert(MI);
-  assert(sForm(*MI) != AArch64::INSTRUCTION_LIST_END);
-  assert(CmpInstr);
+static bool canInstrSubstituteCmpInstr(MachineInstr &MI, MachineInstr &CmpInstr,
+                                       const TargetRegisterInfo &TRI) {
+  assert(sForm(MI) != AArch64::INSTRUCTION_LIST_END);
 
-  const unsigned CmpOpcode = CmpInstr->getOpcode();
+  const unsigned CmpOpcode = CmpInstr.getOpcode();
   if (!isADDSRegImm(CmpOpcode) && !isSUBSRegImm(CmpOpcode))
     return false;
 
-  if (MI->getParent() != CmpInstr->getParent())
-    return false;
-
-  if (areCFlagsAliveInSuccessors(CmpInstr->getParent()))
+  if (!examineCFlagsUse(MI, CmpInstr, TRI))
     return false;
 
   AccessKind AccessToCheck = AK_Write;
-  if (sForm(*MI) != MI->getOpcode())
+  if (sForm(MI) != MI.getOpcode())
     AccessToCheck = AK_All;
-  if (areCFlagsAccessedBetweenInstrs(MI, CmpInstr, TRI, AccessToCheck))
-    return false;
-
-  UsedNZCV NZCVUsedAfterCmp;
-  for (const MachineInstr &Instr :
-       instructionsWithoutDebug(std::next(CmpInstr->getIterator()),
-                                CmpInstr->getParent()->instr_end())) {
-    if (Instr.readsRegister(AArch64::NZCV, TRI)) {
-      AArch64CC::CondCode CC = findCondCodeUsedByInstr(Instr);
-      if (CC == AArch64CC::Invalid) // Unsupported conditional instruction
-        return false;
-      NZCVUsedAfterCmp |= getUsedNZCV(CC);
-    }
-
-    if (Instr.modifiesRegister(AArch64::NZCV, TRI))
-      break;
-  }
-
-  return !NZCVUsedAfterCmp.C && !NZCVUsedAfterCmp.V;
+  return !areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AccessToCheck);
 }
 
 /// Substitute an instruction comparing to zero with another instruction
@@ -1692,20 +1716,19 @@ static bool canInstrSubstituteCmpInstr(MachineInstr *MI, MachineInstr *CmpInstr,
 /// Return true on success.
 bool AArch64InstrInfo::substituteCmpToZero(
     MachineInstr &CmpInstr, unsigned SrcReg,
-    const MachineRegisterInfo *MRI) const {
-  assert(MRI);
+    const MachineRegisterInfo &MRI) const {
   // Get the unique definition of SrcReg.
-  MachineInstr *MI = MRI->getUniqueVRegDef(SrcReg);
+  MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
   if (!MI)
     return false;
 
-  const TargetRegisterInfo *TRI = &getRegisterInfo();
+  const TargetRegisterInfo &TRI = getRegisterInfo();
 
   unsigned NewOpc = sForm(*MI);
   if (NewOpc == AArch64::INSTRUCTION_LIST_END)
     return false;
 
-  if (!canInstrSubstituteCmpInstr(MI, &CmpInstr, TRI))
+  if (!canInstrSubstituteCmpInstr(*MI, CmpInstr, TRI))
     return false;
 
   // Update the instruction to set NZCV.
@@ -1714,7 +1737,133 @@ bool AArch64InstrInfo::substituteCmpToZero(
   bool succeeded = UpdateOperandRegClass(*MI);
   (void)succeeded;
   assert(succeeded && "Some operands reg class are incompatible!");
-  MI->addRegisterDefined(AArch64::NZCV, TRI);
+  MI->addRegisterDefined(AArch64::NZCV, &TRI);
+  return true;
+}
+
+/// \returns True if \p CmpInstr can be removed.
+///
+/// On success \p IsInvertCC tells whether the condition codes used by
+/// \p CCUseInstrs must be inverted once \p CmpInstr is removed.
+static bool canCmpInstrBeRemoved(MachineInstr &MI, MachineInstr &CmpInstr,
+                                 int CmpValue, const TargetRegisterInfo &TRI,
+                                 SmallVectorImpl<MachineInstr *> &CCUseInstrs,
+                                 bool &IsInvertCC) {
+  assert((CmpValue == 0 || CmpValue == 1) &&
+         "Only comparisons to 0 or 1 considered for removal!");
+
+  // MI must be 'CSINCWr %vreg, wzr, wzr, <cc>' or 'CSINCXr %vreg, xzr, xzr, <cc>'
+  unsigned MIOpc = MI.getOpcode();
+  if (MIOpc == AArch64::CSINCWr) {
+    if (MI.getOperand(1).getReg() != AArch64::WZR ||
+        MI.getOperand(2).getReg() != AArch64::WZR)
+      return false;
+  } else if (MIOpc == AArch64::CSINCXr) {
+    if (MI.getOperand(1).getReg() != AArch64::XZR ||
+        MI.getOperand(2).getReg() != AArch64::XZR)
+      return false;
+  } else {
+    return false;
+  }
+  AArch64CC::CondCode MICC = findCondCodeUsedByInstr(MI);
+  if (MICC == AArch64CC::Invalid)
+    return false;
+
+  // Bail out if MI has a (presumably dead, per the 'true' flag) NZCV def.
+  if (MI.findRegisterDefOperandIdx(AArch64::NZCV, true) != -1)
+    return false;
+
+  // CmpInstr is 'ADDS %vreg, 0' or 'SUBS %vreg, 0' or 'SUBS %vreg, 1'
+  const unsigned CmpOpcode = CmpInstr.getOpcode();
+  bool IsSubsRegImm = isSUBSRegImm(CmpOpcode);
+  if (CmpValue && !IsSubsRegImm)
+    return false;
+  if (!CmpValue && !IsSubsRegImm && !isADDSRegImm(CmpOpcode))
+    return false;
+
+  // MI conditions allowed: eq, ne, mi, pl
+  UsedNZCV MIUsedNZCV = getUsedNZCV(MICC);
+  if (MIUsedNZCV.C || MIUsedNZCV.V)
+    return false;
+
+  Optional<UsedNZCV> NZCVUsedAfterCmp =
+      examineCFlagsUse(MI, CmpInstr, TRI, &CCUseInstrs);
+  // Condition flags must not be used in successors of CmpInstr's basic block,
+  // and only the Z or N flag may be used after CmpInstr within that block.
+  if (!NZCVUsedAfterCmp)
+    return false;
+  // Z or N flag used after CmpInstr must correspond to the flag used in MI
+  if ((MIUsedNZCV.Z && NZCVUsedAfterCmp->N) ||
+      (MIUsedNZCV.N && NZCVUsedAfterCmp->Z))
+    return false;
+  // If CmpInstr is a comparison to zero, MI conditions are limited to eq, ne
+  if (MIUsedNZCV.N && !CmpValue)
+    return false;
+
+  // There must be no defs of flags between MI and CmpInstr
+  if (areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AK_Write))
+    return false;
+
+  // Condition code is inverted in the following cases:
+  // 1. MI condition is ne; CmpInstr is 'ADDS %vreg, 0' or 'SUBS %vreg, 0'
+  // 2. MI condition is eq, pl; CmpInstr is 'SUBS %vreg, 1'
+  IsInvertCC = (CmpValue && (MICC == AArch64CC::EQ || MICC == AArch64CC::PL)) ||
+               (!CmpValue && MICC == AArch64CC::NE);
+  return true;
+}
+
+/// Remove the comparison in a csinc-cmp sequence.
+///
+/// Examples:
+/// 1. \code
+///   csinc w9, wzr, wzr, ne
+///   cmp   w9, #0
+///   b.eq
+///    \endcode
+/// to
+///    \code
+///   csinc w9, wzr, wzr, ne
+///   b.ne
+///    \endcode
+///
+/// 2. \code
+///   csinc x2, xzr, xzr, mi
+///   cmp   x2, #1
+///   b.pl
+///    \endcode
+/// to
+///    \code
+///   csinc x2, xzr, xzr, mi
+///   b.pl
+///    \endcode
+///
+/// \param  CmpInstr comparison instruction (compares \p SrcReg to \p CmpValue)
+/// \return True when the comparison was removed
+bool AArch64InstrInfo::removeCmpToZeroOrOne(
+    MachineInstr &CmpInstr, unsigned SrcReg, int CmpValue,
+    const MachineRegisterInfo &MRI) const {
+  MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
+  if (!MI)
+    return false;
+  const TargetRegisterInfo &TRI = getRegisterInfo();
+  SmallVector<MachineInstr *, 4> CCUseInstrs;
+  bool IsInvertCC = false;
+  if (!canCmpInstrBeRemoved(*MI, CmpInstr, CmpValue, TRI, CCUseInstrs,
+                            IsInvertCC))
+    return false;
+  // Perform the transformation: drop the compare, then fix up its flag users.
+  CmpInstr.eraseFromParent();
+  if (IsInvertCC) {
+    // Invert condition codes in CmpInstr CC users
+    for (MachineInstr *CCUseInstr : CCUseInstrs) {
+      int Idx = findCondCodeUseOperandIdxForBranchOrSelect(*CCUseInstr);
+      assert(Idx >= 0 && "Unexpected instruction using CC.");
+      MachineOperand &CCOperand = CCUseInstr->getOperand(Idx);
+      AArch64CC::CondCode CCUse = AArch64CC::getInvertedCondCode(
+          static_cast<AArch64CC::CondCode>(CCOperand.getImm()));
+      CCOperand.setImm(CCUse);
+    }
+  }
   return true;
 }
 

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 8a724d1a1fee..29492af716be 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -334,7 +334,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
                              MachineBasicBlock *TBB,
                              ArrayRef<MachineOperand> Cond) const;
   bool substituteCmpToZero(MachineInstr &CmpInstr, unsigned SrcReg,
-                           const MachineRegisterInfo *MRI) const;
+                           const MachineRegisterInfo &MRI) const;
+  bool removeCmpToZeroOrOne(MachineInstr &CmpInstr, unsigned SrcReg,
+                            int CmpValue, const MachineRegisterInfo &MRI) const;
 
   /// Returns an unused general-purpose register which can be used for
   /// constructing an outlined call if one exists. Returns 0 otherwise.

diff  --git a/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir b/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
index 932c9aa21314..4222e84b113c 100644
--- a/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
+++ b/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
@@ -12,7 +12,6 @@ body:             |
   ; CHECK:   [[DEF:%[0-9]+]]:gpr64 = IMPLICIT_DEF
   ; CHECK:   [[SUBSXrr:%[0-9]+]]:gpr64 = SUBSXrr killed [[DEF]], [[COPY]], implicit-def $nzcv
   ; CHECK:   [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 1, implicit $nzcv
-  ; CHECK:   [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri killed [[CSINCWr]], 1, 0, implicit-def $nzcv
   ; CHECK:   Bcc 1, %bb.2, implicit $nzcv
   ; CHECK:   B %bb.1
   ; CHECK: bb.1:
@@ -51,8 +50,7 @@ body:             |
   ; CHECK:   [[DEF:%[0-9]+]]:gpr64 = IMPLICIT_DEF
   ; CHECK:   [[SUBSXrr:%[0-9]+]]:gpr64 = SUBSXrr killed [[DEF]], [[COPY]], implicit-def $nzcv
   ; CHECK:   [[CSINCXr:%[0-9]+]]:gpr64common = CSINCXr $xzr, $xzr, 1, implicit $nzcv
-  ; CHECK:   [[SUBSXri:%[0-9]+]]:gpr64 = SUBSXri killed [[CSINCXr]], 0, 0, implicit-def $nzcv
-  ; CHECK:   Bcc 0, %bb.2, implicit $nzcv
+  ; CHECK:   Bcc 1, %bb.2, implicit $nzcv
   ; CHECK:   B %bb.1
   ; CHECK: bb.1:
   ; CHECK:   successors: %bb.2(0x80000000)
@@ -155,8 +153,7 @@ body:             |
   ; CHECK:   successors: %bb.1(0x40000000), %bb.2(0x40000000)
   ; CHECK:   liveins: $nzcv
   ; CHECK:   [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 1, implicit $nzcv
-  ; CHECK:   [[ADDSWri:%[0-9]+]]:gpr32 = ADDSWri killed [[CSINCWr]], 0, 0, implicit-def $nzcv
-  ; CHECK:   Bcc 1, %bb.2, implicit $nzcv
+  ; CHECK:   Bcc 0, %bb.2, implicit $nzcv
   ; CHECK:   B %bb.1
   ; CHECK: bb.1:
   ; CHECK:   successors: %bb.2(0x80000000)
@@ -254,8 +251,7 @@ body:             |
   ; CHECK:   successors: %bb.1(0x40000000), %bb.2(0x40000000)
   ; CHECK:   liveins: $nzcv
   ; CHECK:   [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 5, implicit $nzcv
-  ; CHECK:   [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri killed [[CSINCWr]], 1, 0, implicit-def $nzcv
-  ; CHECK:   Bcc 4, %bb.2, implicit $nzcv
+  ; CHECK:   Bcc 5, %bb.2, implicit $nzcv
   ; CHECK:   B %bb.1
   ; CHECK: bb.1:
   ; CHECK:   successors: %bb.2(0x80000000)

diff  --git a/llvm/test/CodeGen/AArch64/f16-instructions.ll b/llvm/test/CodeGen/AArch64/f16-instructions.ll
index bb445f08d1ed..802967bc8598 100644
--- a/llvm/test/CodeGen/AArch64/f16-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/f16-instructions.ll
@@ -189,8 +189,6 @@ define half @test_select(half %a, half %b, i1 zeroext %c) #0 {
 ; CHECK-CVT-DAG: fcvt s1, h1
 ; CHECK-CVT-DAG: fcvt s0, h0
 ; CHECK-CVT-DAG: fcmp s2, s3
-; CHECK-CVT-DAG: cset [[CC:w[0-9]+]], ne
-; CHECK-CVT-DAG: cmp [[CC]], #0
 ; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
 ; CHECK-CVT-NEXT: fcvt h0, s0
 ; CHECK-CVT-NEXT: ret
@@ -228,8 +226,6 @@ define float @test_select_cc_f32_f16(float %a, float %b, half %c, half %d) #0 {
 ; CHECK-CVT-DAG:  fcvt s0, h0
 ; CHECK-CVT-DAG:  fcvt s1, h1
 ; CHECK-CVT-DAG:  fcmp s2, s3
-; CHECK-CVT-DAG:  cset w8, ne
-; CHECK-CVT-NEXT: cmp w8, #0
 ; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
 ; CHECK-CVT-NEXT: fcvt h0, s0
 ; CHECK-CVT-NEXT: ret


        


More information about the llvm-commits mailing list