[llvm] [ARM][AArch64] Optimize MI-PL, noswrap, or equality subtractions and additions (PR #155311)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 26 07:46:05 PDT 2025


https://github.com/AZero13 updated https://github.com/llvm/llvm-project/pull/155311

>From cfca67c20dcdc5731d99f2579ab4979f8de56684 Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Mon, 25 Aug 2025 16:12:23 -0400
Subject: [PATCH 1/2] If we have MI or PL and a sub, use the SUB's own flags
 instead of a separate CMP

This also makes the SUBS-against-zero fold in performCSELCombine redundant, so remove it.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 68 ++++++++++++-------
 llvm/lib/Target/ARM/ARMISelLowering.cpp       | 13 ++++
 2 files changed, 55 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fbd8f7a979d66..9ac21c20a72e8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3586,7 +3586,8 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &DL,
 }
 
 static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
-                              const SDLoc &DL, SelectionDAG &DAG) {
+                              const SDLoc &DL, SelectionDAG &DAG,
+                              bool optimizeMIOrPL = false) {
   EVT VT = LHS.getValueType();
   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
 
@@ -3630,6 +3631,44 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
       // Use result of ANDS
       return LHS.getValue(1);
     }
+
+    if (LHS.getOpcode() == ISD::SUB) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+          (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+        const SDValue SUBSNode =
+            DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT),
+                        LHS.getOperand(0), LHS.getOperand(1));
+        // Replace all users of (sub X, Y) with newly generated (subs X, Y)
+        DAG.ReplaceAllUsesWith(LHS, SUBSNode);
+        return SUBSNode.getValue(1);
+      }
+    } else if (LHS.getOpcode() == AArch64ISD::SUBS) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+          (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+        return LHS.getValue(1);
+      }
+    }
+
+    if (LHS.getOpcode() == ISD::ADD) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+          (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+        const SDValue ADDSNode =
+            DAG.getNode(AArch64ISD::ADDS, DL, DAG.getVTList(VT, FlagsVT),
+                        LHS.getOperand(0), LHS.getOperand(1));
+        // Replace all users of (add X, Y) with newly generated (adds X, Y)
+        DAG.ReplaceAllUsesWith(LHS, ADDSNode);
+        return ADDSNode.getValue(1);
+      }
+    } else if (LHS.getOpcode() == AArch64ISD::ADDS) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL) ||
+          (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+        return LHS.getValue(1);
+      }
+    }
   }
 
   return DAG.getNode(Opcode, DL, DAG.getVTList(VT, FlagsVT), LHS, RHS)
@@ -3843,7 +3882,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
 
     // Produce a normal comparison if we are first in the chain
     if (!CCOp)
-      return emitComparison(LHS, RHS, CC, DL, DAG);
+      return emitComparison(LHS, RHS, CC, DL, DAG, isInteger);
     // Otherwise produce a ccmp.
     return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
                                      DAG);
@@ -4125,7 +4164,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
   }
 
   if (!Cmp) {
-    Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
+    Cmp = emitComparison(LHS, RHS, CC, DL, DAG, true);
     AArch64CC = changeIntCCToAArch64CC(CC, RHS);
   }
   AArch64cc = getCondCode(DAG, AArch64CC);
@@ -25501,29 +25540,6 @@ static SDValue performCSELCombine(SDNode *N,
     }
   }
 
-  // CSEL a, b, cc, SUBS(SUB(x,y), 0) -> CSEL a, b, cc, SUBS(x,y) if cc doesn't
-  // use overflow flags, to avoid the comparison with zero. In case of success,
-  // this also replaces the original SUB(x,y) with the newly created SUBS(x,y).
-  // NOTE: Perhaps in the future use performFlagSettingCombine to replace SUB
-  // nodes with their SUBS equivalent as is already done for other flag-setting
-  // operators, in which case doing the replacement here becomes redundant.
-  if (Cond.getOpcode() == AArch64ISD::SUBS && Cond->hasNUsesOfValue(1, 1) &&
-      isNullConstant(Cond.getOperand(1))) {
-    SDValue Sub = Cond.getOperand(0);
-    AArch64CC::CondCode CC =
-        static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
-    if (Sub.getOpcode() == ISD::SUB &&
-        (CC == AArch64CC::EQ || CC == AArch64CC::NE || CC == AArch64CC::MI ||
-         CC == AArch64CC::PL)) {
-      SDLoc DL(N);
-      SDValue Subs = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(),
-                                 Sub.getOperand(0), Sub.getOperand(1));
-      DCI.CombineTo(Sub.getNode(), Subs);
-      DCI.CombineTo(Cond.getNode(), Subs, Subs.getValue(1));
-      return SDValue(N, 0);
-    }
-  }
-
   // CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
   if (SDValue CondLast = foldCSELofLASTB(N, DAG))
     return CondLast;
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 12d2d678ff63a..67ff34dcd336c 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -4808,6 +4808,19 @@ SDValue ARMTargetLowering::getARMCmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
     CompareType = ARMISD::CMPZ;
     break;
   }
+
+  // For MI/PL/EQ/NE (or signed conds on an nsw SUB), compare the SUB's operands directly. NOTE(review): RHS is not checked to be 0 here - confirm.
+  if (CondCode == ARMCC::MI || CondCode == ARMCC::PL || CondCode == ARMCC::EQ ||
+      CondCode == ARMCC::NE ||
+      (LHS->getFlags().hasNoSignedWrap() &&
+       (CondCode == ARMCC::LT || CondCode == ARMCC::GE ||
+        CondCode == ARMCC::LE || CondCode == ARMCC::GT))) {
+    if (LHS.getOpcode() == ISD::SUB) {
+      ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
+      return DAG.getNode(CompareType, dl, FlagsVT, LHS.getOperand(0),
+                         LHS.getOperand(1));
+    }
+  }
   ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
   return DAG.getNode(CompareType, dl, FlagsVT, LHS, RHS);
 }

>From 67927c1f9cce4d8b2c267513fdbf240a705a461d Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Tue, 26 Aug 2025 09:00:48 -0400
Subject: [PATCH 2/2] [AArch64][GlobalISel] Fold G_ICMP of G_SUB/G_ADD against 0 into SUBS/CMN

---
 .../GISel/AArch64InstructionSelector.cpp      | 30 +++++++++++++++++++
 llvm/test/CodeGen/AArch64/abds-neg.ll         | 15 ++++++----
 2 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 0bceb322726d1..ff197c5ad352f 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -5102,6 +5102,36 @@ MachineInstr *AArch64InstructionSelector::tryFoldIntegerCompare(
     return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
   }
 
+  // Given this:
+  //
+  // z = G_SUB/G_ADD x, y
+  // G_ICMP z, 0
+  //
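+  // Produce this for EQ/NE/SLT/SGE, or any non-unsigned predicate if z is nsw (NOTE(review): without nsw, SLT/SGE are only safe as MI/PL - confirm):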
+  //
+  // cmp/cmn x, y
+  if ((LHSDef->getFlag(MachineInstr::NoSWrap) && !CmpInst::isUnsigned(P)) ||
+      (P == CmpInst::ICMP_EQ || P == CmpInst::ICMP_NE ||
+       P == CmpInst::ICMP_SLT || P == CmpInst::ICMP_SGE)) {
+
+    if (LHSDef->getOpcode() == TargetOpcode::G_SUB ||
+        LHSDef->getOpcode() == TargetOpcode::G_ADD) {
+      // Make sure that the RHS is 0.
+      auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS.getReg(), MRI);
+      if (!ValAndVReg || ValAndVReg->Value != 0)
+        return nullptr;
+
+      if (LHSDef->getOpcode() == TargetOpcode::G_SUB) {
+        auto Dst = MRI.cloneVirtualRegister(LHS.getReg());
+        return emitSUBS(Dst, LHSDef->getOperand(1), LHSDef->getOperand(2),
+                        MIRBuilder);
+      } else {
+        return emitCMN(LHSDef->getOperand(1), LHSDef->getOperand(2),
+                       MIRBuilder);
+      }
+    }
+  }
+
   // Given this:
   //
   // z = G_AND x, y
diff --git a/llvm/test/CodeGen/AArch64/abds-neg.ll b/llvm/test/CodeGen/AArch64/abds-neg.ll
index 02c76ba7343a0..75247823ee793 100644
--- a/llvm/test/CodeGen/AArch64/abds-neg.ll
+++ b/llvm/test/CodeGen/AArch64/abds-neg.ll
@@ -9,7 +9,8 @@ define i8 @abd_ext_i8(i8 %a, i8 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxtb
+; CHECK-NEXT:    sub w8, w8, w1, sxtb
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i8 %a to i64
@@ -25,7 +26,8 @@ define i8 @abd_ext_i8_i16(i8 %a, i16 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i8_i16:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxth
+; CHECK-NEXT:    sub w8, w8, w1, sxth
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i8 %a to i64
@@ -41,7 +43,8 @@ define i8 @abd_ext_i8_undef(i8 %a, i8 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i8_undef:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxtb
+; CHECK-NEXT:    sub w8, w8, w1, sxtb
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i8 %a to i64
@@ -57,7 +60,8 @@ define i16 @abd_ext_i16(i16 %a, i16 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i16:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxth w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxth
+; CHECK-NEXT:    sub w8, w8, w1, sxth
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i16 %a to i64
@@ -89,7 +93,8 @@ define i16 @abd_ext_i16_undef(i16 %a, i16 %b) nounwind {
 ; CHECK-LABEL: abd_ext_i16_undef:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sxth w8, w0
-; CHECK-NEXT:    subs w8, w8, w1, sxth
+; CHECK-NEXT:    sub w8, w8, w1, sxth
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cneg w0, w8, pl
 ; CHECK-NEXT:    ret
   %aext = sext i16 %a to i64



More information about the llvm-commits mailing list