[llvm] 2ec1610 - [AArch64] Peephole rule to remove redundant cmp after cset.
Pavel Iliin via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 19 11:59:08 PDT 2021
Author: Pavel Iliin
Date: 2021-04-19T19:58:38+01:00
New Revision: 2ec16103c68528669080040629961217662353cd
URL: https://github.com/llvm/llvm-project/commit/2ec16103c68528669080040629961217662353cd
DIFF: https://github.com/llvm/llvm-project/commit/2ec16103c68528669080040629961217662353cd.diff
LOG: [AArch64] Peephole rule to remove redundant cmp after cset.
Comparisons to zero or one after cset instructions can be safely
removed in examples like:
cset w9, eq               cset w9, eq
cmp  w9, #1      --->     <removed>
b.ne .L1                  b.ne .L1

cset w9, eq               cset w9, eq
cmp  w9, #0      --->     <removed>
b.ne .L1                  b.eq .L1
In the second example the branch condition is inverted: with the cmp present, cmp w9, #0 followed by b.ne branches exactly when the cset produced 1, i.e. when the original eq condition held, so once the cmp is removed the branch must test eq on the original flags. A peephole optimization is added to detect such cases and remove the redundant comparisons.
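For orientation, here is a minimal C++ sketch (hypothetical source, not taken from the patch) of the kind of code that can lower to this cset/cmp/branch shape, where a boolean materialized from one comparison is immediately re-tested before branching:

  extern void on_equal();
  extern void on_not_equal();

  // The boolean 'eq' is materialized with cset; before this peephole the
  // branch re-tested it with an extra cmp against #0 instead of reusing
  // the flags already set by the original comparison.
  void branch_on_equal(int a, int b) {
    bool eq = (a == b);   // cset w9, eq
    if (eq)               // cmp w9, #0 ; b.eq/b.ne -- the cmp is now folded
      on_equal();
    else
      on_not_equal();
  }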
Differential Revision: https://reviews.llvm.org/D98564
Added:
Modified:
llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
llvm/lib/Target/AArch64/AArch64InstrInfo.h
llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
llvm/test/CodeGen/AArch64/f16-instructions.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 64adc973beeb..28edf104fc0b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1463,14 +1463,16 @@ bool AArch64InstrInfo::optimizeCompareInstr(
// FIXME:CmpValue has already been converted to 0 or 1 in analyzeCompare
// function.
assert((CmpValue == 0 || CmpValue == 1) && "CmpValue must be 0 or 1!");
- if (CmpValue != 0 || SrcReg2 != 0)
+ if (SrcReg2 != 0)
return false;
// CmpInstr is a Compare instruction if destination register is not used.
if (!MRI->use_nodbg_empty(CmpInstr.getOperand(0).getReg()))
return false;
- return substituteCmpToZero(CmpInstr, SrcReg, MRI);
+ if (!CmpValue && substituteCmpToZero(CmpInstr, SrcReg, *MRI))
+ return true;
+ return removeCmpToZeroOrOne(CmpInstr, SrcReg, CmpValue, *MRI);
}
/// Get opcode of S version of Instr.
@@ -1524,13 +1526,44 @@ static unsigned sForm(MachineInstr &Instr) {
}
/// Check if AArch64::NZCV should be alive in successors of MBB.
-static bool areCFlagsAliveInSuccessors(MachineBasicBlock *MBB) {
+static bool areCFlagsAliveInSuccessors(const MachineBasicBlock *MBB) {
for (auto *BB : MBB->successors())
if (BB->isLiveIn(AArch64::NZCV))
return true;
return false;
}
+/// \returns The condition code operand index for \p Instr if it is a branch
+/// or select and -1 otherwise.
+static int
+findCondCodeUseOperandIdxForBranchOrSelect(const MachineInstr &Instr) {
+ switch (Instr.getOpcode()) {
+ default:
+ return -1;
+
+ case AArch64::Bcc: {
+ int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
+ assert(Idx >= 2);
+ return Idx - 2;
+ }
+
+ case AArch64::CSINVWr:
+ case AArch64::CSINVXr:
+ case AArch64::CSINCWr:
+ case AArch64::CSINCXr:
+ case AArch64::CSELWr:
+ case AArch64::CSELXr:
+ case AArch64::CSNEGWr:
+ case AArch64::CSNEGXr:
+ case AArch64::FCSELSrrr:
+ case AArch64::FCSELDrrr: {
+ int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
+ assert(Idx >= 1);
+ return Idx - 1;
+ }
+ }
+}
+
namespace {
struct UsedNZCV {
@@ -1556,31 +1589,10 @@ struct UsedNZCV {
/// Returns AArch64CC::Invalid if either the instruction does not use condition
/// codes or we don't optimize CmpInstr in the presence of such instructions.
static AArch64CC::CondCode findCondCodeUsedByInstr(const MachineInstr &Instr) {
- switch (Instr.getOpcode()) {
- default:
- return AArch64CC::Invalid;
-
- case AArch64::Bcc: {
- int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
- assert(Idx >= 2);
- return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 2).getImm());
- }
-
- case AArch64::CSINVWr:
- case AArch64::CSINVXr:
- case AArch64::CSINCWr:
- case AArch64::CSINCXr:
- case AArch64::CSELWr:
- case AArch64::CSELXr:
- case AArch64::CSNEGWr:
- case AArch64::CSNEGXr:
- case AArch64::FCSELSrrr:
- case AArch64::FCSELDrrr: {
- int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
- assert(Idx >= 1);
- return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 1).getImm());
- }
- }
+ int CCIdx = findCondCodeUseOperandIdxForBranchOrSelect(Instr);
+ return CCIdx >= 0 ? static_cast<AArch64CC::CondCode>(
+ Instr.getOperand(CCIdx).getImm())
+ : AArch64CC::Invalid;
}
static UsedNZCV getUsedNZCV(AArch64CC::CondCode CC) {
@@ -1627,6 +1639,41 @@ static UsedNZCV getUsedNZCV(AArch64CC::CondCode CC) {
return UsedFlags;
}
+/// \returns The condition flags used after \p CmpInstr in its MachineBB if
+/// they do not include the C or V flags and NZCV is not live into any
+/// successor of the parent block of \p CmpInstr and \p MI. \returns None otherwise.
+///
+/// Instructions using those flags are collected in \p CCUseInstrs if provided.
+static Optional<UsedNZCV>
+examineCFlagsUse(MachineInstr &MI, MachineInstr &CmpInstr,
+ const TargetRegisterInfo &TRI,
+ SmallVectorImpl<MachineInstr *> *CCUseInstrs = nullptr) {
+ MachineBasicBlock *CmpParent = CmpInstr.getParent();
+ if (MI.getParent() != CmpParent)
+ return None;
+
+ if (areCFlagsAliveInSuccessors(CmpParent))
+ return None;
+
+ UsedNZCV NZCVUsedAfterCmp;
+ for (MachineInstr &Instr : instructionsWithoutDebug(
+ std::next(CmpInstr.getIterator()), CmpParent->instr_end())) {
+ if (Instr.readsRegister(AArch64::NZCV, &TRI)) {
+ AArch64CC::CondCode CC = findCondCodeUsedByInstr(Instr);
+ if (CC == AArch64CC::Invalid) // Unsupported conditional instruction
+ return None;
+ NZCVUsedAfterCmp |= getUsedNZCV(CC);
+ if (CCUseInstrs)
+ CCUseInstrs->push_back(&Instr);
+ }
+ if (Instr.modifiesRegister(AArch64::NZCV, &TRI))
+ break;
+ }
+ if (NZCVUsedAfterCmp.C || NZCVUsedAfterCmp.V)
+ return None;
+ return NZCVUsedAfterCmp;
+}
+
static bool isADDSRegImm(unsigned Opcode) {
return Opcode == AArch64::ADDSWri || Opcode == AArch64::ADDSXri;
}
@@ -1646,44 +1693,21 @@ static bool isSUBSRegImm(unsigned Opcode) {
/// or if MI opcode is not the S form there must be neither defs of flags
/// nor uses of flags between MI and CmpInstr.
/// - and C/V flags are not used after CmpInstr
-static bool canInstrSubstituteCmpInstr(MachineInstr *MI, MachineInstr *CmpInstr,
- const TargetRegisterInfo *TRI) {
- assert(MI);
- assert(sForm(*MI) != AArch64::INSTRUCTION_LIST_END);
- assert(CmpInstr);
+static bool canInstrSubstituteCmpInstr(MachineInstr &MI, MachineInstr &CmpInstr,
+ const TargetRegisterInfo &TRI) {
+ assert(sForm(MI) != AArch64::INSTRUCTION_LIST_END);
- const unsigned CmpOpcode = CmpInstr->getOpcode();
+ const unsigned CmpOpcode = CmpInstr.getOpcode();
if (!isADDSRegImm(CmpOpcode) && !isSUBSRegImm(CmpOpcode))
return false;
- if (MI->getParent() != CmpInstr->getParent())
- return false;
-
- if (areCFlagsAliveInSuccessors(CmpInstr->getParent()))
+ if (!examineCFlagsUse(MI, CmpInstr, TRI))
return false;
AccessKind AccessToCheck = AK_Write;
- if (sForm(*MI) != MI->getOpcode())
+ if (sForm(MI) != MI.getOpcode())
AccessToCheck = AK_All;
- if (areCFlagsAccessedBetweenInstrs(MI, CmpInstr, TRI, AccessToCheck))
- return false;
-
- UsedNZCV NZCVUsedAfterCmp;
- for (const MachineInstr &Instr :
- instructionsWithoutDebug(std::next(CmpInstr->getIterator()),
- CmpInstr->getParent()->instr_end())) {
- if (Instr.readsRegister(AArch64::NZCV, TRI)) {
- AArch64CC::CondCode CC = findCondCodeUsedByInstr(Instr);
- if (CC == AArch64CC::Invalid) // Unsupported conditional instruction
- return false;
- NZCVUsedAfterCmp |= getUsedNZCV(CC);
- }
-
- if (Instr.modifiesRegister(AArch64::NZCV, TRI))
- break;
- }
-
- return !NZCVUsedAfterCmp.C && !NZCVUsedAfterCmp.V;
+ return !areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AccessToCheck);
}
/// Substitute an instruction comparing to zero with another instruction
@@ -1692,20 +1716,19 @@ static bool canInstrSubstituteCmpInstr(MachineInstr *MI, MachineInstr *CmpInstr,
/// Return true on success.
bool AArch64InstrInfo::substituteCmpToZero(
MachineInstr &CmpInstr, unsigned SrcReg,
- const MachineRegisterInfo *MRI) const {
- assert(MRI);
+ const MachineRegisterInfo &MRI) const {
// Get the unique definition of SrcReg.
- MachineInstr *MI = MRI->getUniqueVRegDef(SrcReg);
+ MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
if (!MI)
return false;
- const TargetRegisterInfo *TRI = &getRegisterInfo();
+ const TargetRegisterInfo &TRI = getRegisterInfo();
unsigned NewOpc = sForm(*MI);
if (NewOpc == AArch64::INSTRUCTION_LIST_END)
return false;
- if (!canInstrSubstituteCmpInstr(MI, &CmpInstr, TRI))
+ if (!canInstrSubstituteCmpInstr(*MI, CmpInstr, TRI))
return false;
// Update the instruction to set NZCV.
@@ -1714,7 +1737,133 @@ bool AArch64InstrInfo::substituteCmpToZero(
bool succeeded = UpdateOperandRegClass(*MI);
(void)succeeded;
assert(succeeded && "Some operands reg class are incompatible!");
- MI->addRegisterDefined(AArch64::NZCV, TRI);
+ MI->addRegisterDefined(AArch64::NZCV, &TRI);
+ return true;
+}
+
+/// \returns True if \p CmpInstr can be removed.
+///
+/// \p IsInvertCC is true if, after removing \p CmpInstr, condition
+/// codes used in \p CCUseInstrs must be inverted.
+static bool canCmpInstrBeRemoved(MachineInstr &MI, MachineInstr &CmpInstr,
+ int CmpValue, const TargetRegisterInfo &TRI,
+ SmallVectorImpl<MachineInstr *> &CCUseInstrs,
+ bool &IsInvertCC) {
+ assert((CmpValue == 0 || CmpValue == 1) &&
+ "Only comparisons to 0 or 1 considered for removal!");
+
+ // MI is 'CSINCWr %vreg, wzr, wzr, <cc>' or 'CSINCXr %vreg, xzr, xzr, <cc>'
+ unsigned MIOpc = MI.getOpcode();
+ if (MIOpc == AArch64::CSINCWr) {
+ if (MI.getOperand(1).getReg() != AArch64::WZR ||
+ MI.getOperand(2).getReg() != AArch64::WZR)
+ return false;
+ } else if (MIOpc == AArch64::CSINCXr) {
+ if (MI.getOperand(1).getReg() != AArch64::XZR ||
+ MI.getOperand(2).getReg() != AArch64::XZR)
+ return false;
+ } else {
+ return false;
+ }
+ AArch64CC::CondCode MICC = findCondCodeUsedByInstr(MI);
+ if (MICC == AArch64CC::Invalid)
+ return false;
+
+ // NZCV needs to be defined
+ if (MI.findRegisterDefOperandIdx(AArch64::NZCV, true) != -1)
+ return false;
+
+ // CmpInstr is 'ADDS %vreg, 0' or 'SUBS %vreg, 0' or 'SUBS %vreg, 1'
+ const unsigned CmpOpcode = CmpInstr.getOpcode();
+ bool IsSubsRegImm = isSUBSRegImm(CmpOpcode);
+ if (CmpValue && !IsSubsRegImm)
+ return false;
+ if (!CmpValue && !IsSubsRegImm && !isADDSRegImm(CmpOpcode))
+ return false;
+
+ // MI conditions allowed: eq, ne, mi, pl
+ UsedNZCV MIUsedNZCV = getUsedNZCV(MICC);
+ if (MIUsedNZCV.C || MIUsedNZCV.V)
+ return false;
+
+ Optional<UsedNZCV> NZCVUsedAfterCmp =
+ examineCFlagsUse(MI, CmpInstr, TRI, &CCUseInstrs);
+  // Condition flags must not be used in the successors of CmpInstr's basic
+  // block, and only the Z or N flags may be used after CmpInstr within it.
+ if (!NZCVUsedAfterCmp)
+ return false;
+ // Z or N flag used after CmpInstr must correspond to the flag used in MI
+ if ((MIUsedNZCV.Z && NZCVUsedAfterCmp->N) ||
+ (MIUsedNZCV.N && NZCVUsedAfterCmp->Z))
+ return false;
+  // If CmpInstr is a comparison to zero, MI conditions are limited to eq, ne
+ if (MIUsedNZCV.N && !CmpValue)
+ return false;
+
+ // There must be no defs of flags between MI and CmpInstr
+ if (areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AK_Write))
+ return false;
+
+ // Condition code is inverted in the following cases:
+ // 1. MI condition is ne; CmpInstr is 'ADDS %vreg, 0' or 'SUBS %vreg, 0'
+ // 2. MI condition is eq, pl; CmpInstr is 'SUBS %vreg, 1'
+ IsInvertCC = (CmpValue && (MICC == AArch64CC::EQ || MICC == AArch64CC::PL)) ||
+ (!CmpValue && MICC == AArch64CC::NE);
+ return true;
+}
+
+/// Remove the comparison in a csinc-cmp sequence.
+///
+/// Examples:
+/// 1. \code
+/// csinc w9, wzr, wzr, ne
+/// cmp w9, #0
+/// b.eq
+/// \endcode
+/// to
+/// \code
+/// csinc w9, wzr, wzr, ne
+/// b.ne
+/// \endcode
+///
+/// 2. \code
+/// csinc x2, xzr, xzr, mi
+/// cmp x2, #1
+/// b.pl
+/// \endcode
+/// to
+/// \code
+/// csinc x2, xzr, xzr, mi
+/// b.pl
+/// \endcode
+///
+/// \param CmpInstr comparison instruction
+/// \return True when comparison removed
+bool AArch64InstrInfo::removeCmpToZeroOrOne(
+ MachineInstr &CmpInstr, unsigned SrcReg, int CmpValue,
+ const MachineRegisterInfo &MRI) const {
+ MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
+ if (!MI)
+ return false;
+ const TargetRegisterInfo &TRI = getRegisterInfo();
+ SmallVector<MachineInstr *, 4> CCUseInstrs;
+ bool IsInvertCC = false;
+ if (!canCmpInstrBeRemoved(*MI, CmpInstr, CmpValue, TRI, CCUseInstrs,
+ IsInvertCC))
+ return false;
+ // Make transformation
+ CmpInstr.eraseFromParent();
+ if (IsInvertCC) {
+ // Invert condition codes in CmpInstr CC users
+ for (MachineInstr *CCUseInstr : CCUseInstrs) {
+ int Idx = findCondCodeUseOperandIdxForBranchOrSelect(*CCUseInstr);
+ assert(Idx >= 0 && "Unexpected instruction using CC.");
+ MachineOperand &CCOperand = CCUseInstr->getOperand(Idx);
+ AArch64CC::CondCode CCUse = AArch64CC::getInvertedCondCode(
+ static_cast<AArch64CC::CondCode>(CCOperand.getImm()));
+ CCOperand.setImm(CCUse);
+ }
+ }
return true;
}
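To keep the branch-sense decision in one place, here is a hedged standalone C++ restatement of the IsInvertCC computation above (the free-function framing, names, and the local CC enum are mine, not part of the patch):

  // Stand-in for the relevant AArch64CC::CondCode values.
  enum class CC { EQ, NE, MI, PL };

  // MICC is the condition operand of the defining CSINC (recall that
  // "cset w, cc" is the alias "csinc w, wzr, wzr, invert(cc)") and CmpValue
  // is the immediate of the removed ADDS/SUBS (0 or 1). Returns true when
  // every CC user of the removed compare must have its condition inverted.
  bool needsCCInversion(CC MICC, int CmpValue) {
    if (CmpValue == 1)
      // Removing "subs ..., #1": users flip when the CSINC condition is eq or pl.
      return MICC == CC::EQ || MICC == CC::PL;
    // Removing a compare against zero: users flip only when it is ne.
    return MICC == CC::NE;
  }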
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 8a724d1a1fee..29492af716be 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -334,7 +334,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
MachineBasicBlock *TBB,
ArrayRef<MachineOperand> Cond) const;
bool substituteCmpToZero(MachineInstr &CmpInstr, unsigned SrcReg,
- const MachineRegisterInfo *MRI) const;
+ const MachineRegisterInfo &MRI) const;
+ bool removeCmpToZeroOrOne(MachineInstr &CmpInstr, unsigned SrcReg,
+ int CmpValue, const MachineRegisterInfo &MRI) const;
/// Returns an unused general-purpose register which can be used for
/// constructing an outlined call if one exists. Returns 0 otherwise.
diff --git a/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir b/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
index 932c9aa21314..4222e84b113c 100644
--- a/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
+++ b/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
@@ -12,7 +12,6 @@ body: |
; CHECK: [[DEF:%[0-9]+]]:gpr64 = IMPLICIT_DEF
; CHECK: [[SUBSXrr:%[0-9]+]]:gpr64 = SUBSXrr killed [[DEF]], [[COPY]], implicit-def $nzcv
; CHECK: [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 1, implicit $nzcv
- ; CHECK: [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri killed [[CSINCWr]], 1, 0, implicit-def $nzcv
; CHECK: Bcc 1, %bb.2, implicit $nzcv
; CHECK: B %bb.1
; CHECK: bb.1:
@@ -51,8 +50,7 @@ body: |
; CHECK: [[DEF:%[0-9]+]]:gpr64 = IMPLICIT_DEF
; CHECK: [[SUBSXrr:%[0-9]+]]:gpr64 = SUBSXrr killed [[DEF]], [[COPY]], implicit-def $nzcv
; CHECK: [[CSINCXr:%[0-9]+]]:gpr64common = CSINCXr $xzr, $xzr, 1, implicit $nzcv
- ; CHECK: [[SUBSXri:%[0-9]+]]:gpr64 = SUBSXri killed [[CSINCXr]], 0, 0, implicit-def $nzcv
- ; CHECK: Bcc 0, %bb.2, implicit $nzcv
+ ; CHECK: Bcc 1, %bb.2, implicit $nzcv
; CHECK: B %bb.1
; CHECK: bb.1:
; CHECK: successors: %bb.2(0x80000000)
@@ -155,8 +153,7 @@ body: |
; CHECK: successors: %bb.1(0x40000000), %bb.2(0x40000000)
; CHECK: liveins: $nzcv
; CHECK: [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 1, implicit $nzcv
- ; CHECK: [[ADDSWri:%[0-9]+]]:gpr32 = ADDSWri killed [[CSINCWr]], 0, 0, implicit-def $nzcv
- ; CHECK: Bcc 1, %bb.2, implicit $nzcv
+ ; CHECK: Bcc 0, %bb.2, implicit $nzcv
; CHECK: B %bb.1
; CHECK: bb.1:
; CHECK: successors: %bb.2(0x80000000)
@@ -254,8 +251,7 @@ body: |
; CHECK: successors: %bb.1(0x40000000), %bb.2(0x40000000)
; CHECK: liveins: $nzcv
; CHECK: [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 5, implicit $nzcv
- ; CHECK: [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri killed [[CSINCWr]], 1, 0, implicit-def $nzcv
- ; CHECK: Bcc 4, %bb.2, implicit $nzcv
+ ; CHECK: Bcc 5, %bb.2, implicit $nzcv
; CHECK: B %bb.1
; CHECK: bb.1:
; CHECK: successors: %bb.2(0x80000000)
diff --git a/llvm/test/CodeGen/AArch64/f16-instructions.ll b/llvm/test/CodeGen/AArch64/f16-instructions.ll
index bb445f08d1ed..802967bc8598 100644
--- a/llvm/test/CodeGen/AArch64/f16-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/f16-instructions.ll
@@ -189,8 +189,6 @@ define half @test_select(half %a, half %b, i1 zeroext %c) #0 {
; CHECK-CVT-DAG: fcvt s1, h1
; CHECK-CVT-DAG: fcvt s0, h0
; CHECK-CVT-DAG: fcmp s2, s3
-; CHECK-CVT-DAG: cset [[CC:w[0-9]+]], ne
-; CHECK-CVT-DAG: cmp [[CC]], #0
; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
; CHECK-CVT-NEXT: fcvt h0, s0
; CHECK-CVT-NEXT: ret
@@ -228,8 +226,6 @@ define float @test_select_cc_f32_f16(float %a, float %b, half %c, half %d) #0 {
; CHECK-CVT-DAG: fcvt s0, h0
; CHECK-CVT-DAG: fcvt s1, h1
; CHECK-CVT-DAG: fcmp s2, s3
-; CHECK-CVT-DAG: cset w8, ne
-; CHECK-CVT-NEXT: cmp w8, #0
; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
; CHECK-CVT-NEXT: fcvt h0, s0
; CHECK-CVT-NEXT: ret