[llvm] [ARM][AArch64] Optimize MI/PL, no-signed-wrap, or equality subtractions and additions (PR #155311)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 26 07:46:05 PDT 2025
https://github.com/AZero13 updated https://github.com/llvm/llvm-project/pull/155311
>From cfca67c20dcdc5731d99f2579ab4979f8de56684 Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Mon, 25 Aug 2025 16:12:23 -0400
Subject: [PATCH 1/2] If we have MI or PL and a sub, let the sub set the flags
instead of emitting a separate CMP
This also lets us remove a now-redundant CSEL fold in performCSELCombine.
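For example (a sketch of the intended AArch64 lowering, mirroring the CSEL
fold removed below; the function and register assignments are hypothetical,
not taken from the patch's tests):

  define i32 @sel(i32 %x, i32 %y, i32 %a, i32 %b) {
    %d = sub nsw i32 %x, %y
    %c = icmp slt i32 %d, 0
    %r = select i1 %c, i32 %a, i32 %b
    ret i32 %r
  }

  ; before:
  ;   sub  w8, w0, w1
  ;   cmp  w8, #0
  ;   csel w0, w2, w3, lt
  ; after:
  ;   subs w8, w0, w1
  ;   csel w0, w2, w3, mi

The same idea applies to the ARM change in getARMCmp: for MI/PL/EQ/NE, or any
signed condition when the sub has nsw, the compare against zero is dropped and
the sub's operands feed the flag-setting compare directly.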
---
.../Target/AArch64/AArch64ISelLowering.cpp | 68 ++++++++++++-------
llvm/lib/Target/ARM/ARMISelLowering.cpp | 13 ++++
2 files changed, 55 insertions(+), 26 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fbd8f7a979d66..9ac21c20a72e8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3586,7 +3586,8 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &DL,
}
static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
- const SDLoc &DL, SelectionDAG &DAG) {
+ const SDLoc &DL, SelectionDAG &DAG,
+ bool optimizeMIOrPL = false) {
EVT VT = LHS.getValueType();
const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
@@ -3630,6 +3631,44 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
// Use result of ANDS
return LHS.getValue(1);
}
+
+ if (LHS.getOpcode() == ISD::SUB) {
+ if (LHS->getFlags().hasNoSignedWrap() ||
+ ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+ (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+ const SDValue SUBSNode =
+ DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT),
+ LHS.getOperand(0), LHS.getOperand(1));
+ // Replace all users of (sub X, Y) with the newly generated (subs X, Y)
+ DAG.ReplaceAllUsesWith(LHS, SUBSNode);
+ return SUBSNode.getValue(1);
+ }
+ } else if (LHS.getOpcode() == AArch64ISD::SUBS) {
+ if (LHS->getFlags().hasNoSignedWrap() ||
+ ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+ (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+ return LHS.getValue(1);
+ }
+ }
+
+ if (LHS.getOpcode() == ISD::ADD) {
+ if (LHS->getFlags().hasNoSignedWrap() ||
+ ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+ (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+ const SDValue ADDSNode =
+ DAG.getNode(AArch64ISD::ADDS, DL, DAG.getVTList(VT, FlagsVT),
+ LHS.getOperand(0), LHS.getOperand(1));
+ // Replace all users of (add X, Y) with the newly generated (adds X, Y)
+ DAG.ReplaceAllUsesWith(LHS, ADDSNode);
+ return ADDSNode.getValue(1);
+ }
+ } else if (LHS.getOpcode() == AArch64ISD::ADDS) {
+ if (LHS->getFlags().hasNoSignedWrap() ||
+ ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+ (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+ return LHS.getValue(1);
+ }
+ }
}
return DAG.getNode(Opcode, DL, DAG.getVTList(VT, FlagsVT), LHS, RHS)
@@ -3843,7 +3882,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
// Produce a normal comparison if we are first in the chain
if (!CCOp)
- return emitComparison(LHS, RHS, CC, DL, DAG);
+ return emitComparison(LHS, RHS, CC, DL, DAG, isInteger);
// Otherwise produce a ccmp.
return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
DAG);
@@ -4125,7 +4164,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
}
if (!Cmp) {
- Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
+ Cmp = emitComparison(LHS, RHS, CC, DL, DAG, true);
AArch64CC = changeIntCCToAArch64CC(CC, RHS);
}
AArch64cc = getCondCode(DAG, AArch64CC);
@@ -25501,29 +25540,6 @@ static SDValue performCSELCombine(SDNode *N,
}
}
- // CSEL a, b, cc, SUBS(SUB(x,y), 0) -> CSEL a, b, cc, SUBS(x,y) if cc doesn't
- // use overflow flags, to avoid the comparison with zero. In case of success,
- // this also replaces the original SUB(x,y) with the newly created SUBS(x,y).
- // NOTE: Perhaps in the future use performFlagSettingCombine to replace SUB
- // nodes with their SUBS equivalent as is already done for other flag-setting
- // operators, in which case doing the replacement here becomes redundant.
- if (Cond.getOpcode() == AArch64ISD::SUBS && Cond->hasNUsesOfValue(1, 1) &&
- isNullConstant(Cond.getOperand(1))) {
- SDValue Sub = Cond.getOperand(0);
- AArch64CC::CondCode CC =
- static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
- if (Sub.getOpcode() == ISD::SUB &&
- (CC == AArch64CC::EQ || CC == AArch64CC::NE || CC == AArch64CC::MI ||
- CC == AArch64CC::PL)) {
- SDLoc DL(N);
- SDValue Subs = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(),
- Sub.getOperand(0), Sub.getOperand(1));
- DCI.CombineTo(Sub.getNode(), Subs);
- DCI.CombineTo(Cond.getNode(), Subs, Subs.getValue(1));
- return SDValue(N, 0);
- }
- }
-
// CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
if (SDValue CondLast = foldCSELofLASTB(N, DAG))
return CondLast;
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 12d2d678ff63a..67ff34dcd336c 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -4808,6 +4808,19 @@ SDValue ARMTargetLowering::getARMCmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
CompareType = ARMISD::CMPZ;
break;
}
+
+ // For MI/PL/EQ/NE, or signed conditions when the sub has nsw, let the sub set the flags instead of emitting a CMP.
+ if (CondCode == ARMCC::MI || CondCode == ARMCC::PL || CondCode == ARMCC::EQ ||
+ CondCode == ARMCC::NE ||
+ (LHS->getFlags().hasNoSignedWrap() &&
+ (CondCode == ARMCC::LT || CondCode == ARMCC::GE ||
+ CondCode == ARMCC::LE || CondCode == ARMCC::GT))) {
+ if (LHS.getOpcode() == ISD::SUB) {
+ ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
+ return DAG.getNode(CompareType, dl, FlagsVT, LHS.getOperand(0),
+ LHS.getOperand(1));
+ }
+ }
ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
return DAG.getNode(CompareType, dl, FlagsVT, LHS, RHS);
}
>From 67927c1f9cce4d8b2c267513fdbf240a705a461d Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Tue, 26 Aug 2025 09:00:48 -0400
Subject: [PATCH 2/2] Apply the same fold in GlobalISel and update tests
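A sketch of the pattern the GlobalISel change folds in tryFoldIntegerCompare
(hypothetical gMIR, not taken from the updated tests):

  %z:_(s32) = nsw G_SUB %x:_(s32), %y:_(s32)
  %c:_(s1) = G_ICMP intpred(slt), %z(s32), %zero(s32)

  ; With nsw (or eq/ne/slt/sge), the G_ICMP against zero is selected by
  ; reusing the subtraction's flags, e.g. SUBSWrr %x, %y implicit-def $nzcv;
  ; a G_ADD feeding the compare is selected as a CMN instead.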
---
.../GISel/AArch64InstructionSelector.cpp | 30 +++++++++++++++++++
llvm/test/CodeGen/AArch64/abds-neg.ll | 15 ++++++----
2 files changed, 40 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 0bceb322726d1..ff197c5ad352f 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -5102,6 +5102,36 @@ MachineInstr *AArch64InstructionSelector::tryFoldIntegerCompare(
return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
}
+ // Given this:
+ //
+ // z = G_SUB/G_ADD x, y
+ // G_ICMP z, 0
+ //
+ // Produce this if the compare only uses N and Z (eq/ne/slt/sge) or z has nsw:
+ //
+ // cmp/cmn x, y
+ if ((LHSDef->getFlag(MachineInstr::NoSWrap) && !CmpInst::isUnsigned(P)) ||
+ (P == CmpInst::ICMP_EQ || P == CmpInst::ICMP_NE ||
+ P == CmpInst::ICMP_SLT || P == CmpInst::ICMP_SGE)) {
+
+ if (LHSDef->getOpcode() == TargetOpcode::G_SUB ||
+ LHSDef->getOpcode() == TargetOpcode::G_ADD) {
+ // Make sure that the RHS is 0.
+ auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS.getReg(), MRI);
+ if (!ValAndVReg || ValAndVReg->Value != 0)
+ return nullptr;
+
+ if (LHSDef->getOpcode() == TargetOpcode::G_SUB) {
+ auto Dst = MRI.cloneVirtualRegister(LHS.getReg());
+ return emitSUBS(Dst, LHSDef->getOperand(1), LHSDef->getOperand(2),
+ MIRBuilder);
+ } else {
+ return emitCMN(LHSDef->getOperand(1), LHSDef->getOperand(2),
+ MIRBuilder);
+ }
+ }
+ }
+
// Given this:
//
// z = G_AND x, y
diff --git a/llvm/test/CodeGen/AArch64/abds-neg.ll b/llvm/test/CodeGen/AArch64/abds-neg.ll
index 02c76ba7343a0..75247823ee793 100644
--- a/llvm/test/CodeGen/AArch64/abds-neg.ll
+++ b/llvm/test/CodeGen/AArch64/abds-neg.ll
@@ -9,7 +9,8 @@ define i8 @abd_ext_i8(i8 %a, i8 %b) nounwind {
; CHECK-LABEL: abd_ext_i8:
; CHECK: // %bb.0:
; CHECK-NEXT: sxtb w8, w0
-; CHECK-NEXT: subs w8, w8, w1, sxtb
+; CHECK-NEXT: sub w8, w8, w1, sxtb
+; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cneg w0, w8, pl
; CHECK-NEXT: ret
%aext = sext i8 %a to i64
@@ -25,7 +26,8 @@ define i8 @abd_ext_i8_i16(i8 %a, i16 %b) nounwind {
; CHECK-LABEL: abd_ext_i8_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: sxtb w8, w0
-; CHECK-NEXT: subs w8, w8, w1, sxth
+; CHECK-NEXT: sub w8, w8, w1, sxth
+; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cneg w0, w8, pl
; CHECK-NEXT: ret
%aext = sext i8 %a to i64
@@ -41,7 +43,8 @@ define i8 @abd_ext_i8_undef(i8 %a, i8 %b) nounwind {
; CHECK-LABEL: abd_ext_i8_undef:
; CHECK: // %bb.0:
; CHECK-NEXT: sxtb w8, w0
-; CHECK-NEXT: subs w8, w8, w1, sxtb
+; CHECK-NEXT: sub w8, w8, w1, sxtb
+; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cneg w0, w8, pl
; CHECK-NEXT: ret
%aext = sext i8 %a to i64
@@ -57,7 +60,8 @@ define i16 @abd_ext_i16(i16 %a, i16 %b) nounwind {
; CHECK-LABEL: abd_ext_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: sxth w8, w0
-; CHECK-NEXT: subs w8, w8, w1, sxth
+; CHECK-NEXT: sub w8, w8, w1, sxth
+; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cneg w0, w8, pl
; CHECK-NEXT: ret
%aext = sext i16 %a to i64
@@ -89,7 +93,8 @@ define i16 @abd_ext_i16_undef(i16 %a, i16 %b) nounwind {
; CHECK-LABEL: abd_ext_i16_undef:
; CHECK: // %bb.0:
; CHECK-NEXT: sxth w8, w0
-; CHECK-NEXT: subs w8, w8, w1, sxth
+; CHECK-NEXT: sub w8, w8, w1, sxth
+; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cneg w0, w8, pl
; CHECK-NEXT: ret
%aext = sext i16 %a to i64