[llvm] [ARM][AArch64] Optimize MI-PL, noswrap, or equality subtractions and additions (PR #155311)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 25 14:44:04 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-arm

Author: AZero13 (AZero13)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/155311.diff


3 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+38-3) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+31) 
- (modified) llvm/lib/Target/ARM/ARMISelLowering.cpp (+13) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fbd8f7a979d66..3826ad61a5823 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3586,7 +3586,8 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &DL,
 }
 
 static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
-                              const SDLoc &DL, SelectionDAG &DAG) {
+                              const SDLoc &DL, SelectionDAG &DAG,
+                              bool optimizeMIOrPL = false) {
   EVT VT = LHS.getValueType();
   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
 
@@ -3630,6 +3631,40 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
       // Use result of ANDS
       return LHS.getValue(1);
     }
+
+    if (LHS.getOpcode() == ISD::SUB) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL)) {
+        const SDValue SUBSNode =
+            DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT),
+                        LHS.getOperand(0), LHS.getOperand(1));
+        // Replace all users of (sub X, Y) with newly generated (subs X, Y)
+        DAG.ReplaceAllUsesWith(LHS, SUBSNode);
+        return SUBSNode.getValue(1);
+      }
+    } else if (LHS.getOpcode() == AArch64ISD::SUBS) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL)) {
+        return LHS.getValue(1);
+      }
+    }
+
+    if (LHS.getOpcode() == ISD::ADD) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL)) {
+        const SDValue ADDSNode =
+            DAG.getNode(AArch64ISD::ADDS, DL, DAG.getVTList(VT, FlagsVT),
+                        LHS.getOperand(0), LHS.getOperand(1));
+        // Replace all users of (add X, Y) with newly generated (adds X, Y)
+        DAG.ReplaceAllUsesWith(LHS, ADDSNode);
+        return ADDSNode.getValue(1);
+      }
+    } else if (LHS.getOpcode() == AArch64ISD::ADDS) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL)) {
+        return LHS.getValue(1);
+      }
+    }
   }
 
   return DAG.getNode(Opcode, DL, DAG.getVTList(VT, FlagsVT), LHS, RHS)
@@ -3843,7 +3878,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
 
     // Produce a normal comparison if we are first in the chain
     if (!CCOp)
-      return emitComparison(LHS, RHS, CC, DL, DAG);
+      return emitComparison(LHS, RHS, CC, DL, DAG, isInteger);
     // Otherwise produce a ccmp.
     return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
                                      DAG);
@@ -4125,7 +4160,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
   }
 
   if (!Cmp) {
-    Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
+    Cmp = emitComparison(LHS, RHS, CC, DL, DAG, true);
     AArch64CC = changeIntCCToAArch64CC(CC, RHS);
   }
   AArch64cc = getCondCode(DAG, AArch64CC);
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 0bceb322726d1..a96323e6534eb 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -4466,6 +4466,7 @@ MachineInstr *AArch64InstructionSelector::emitIntegerCompare(
   // Fold the compare into a cmn or tst if possible.
   if (auto FoldCmp = tryFoldIntegerCompare(LHS, RHS, Predicate, MIRBuilder))
     return FoldCmp;
+
   auto Dst = MRI.cloneVirtualRegister(LHS.getReg());
   return emitSUBS(Dst, LHS, RHS, MIRBuilder);
 }
@@ -5102,6 +5103,36 @@ MachineInstr *AArch64InstructionSelector::tryFoldIntegerCompare(
     return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
   }
 
+  // Given this:
+  //
+  // z = G_SUB/G_ADD x, y
+  // G_ICMP z, 0
+  //
+  // Produce this for EQ/NE/SLT/SGE, or for any non-unsigned compare when z is nsw:
+  //
+  // cmp/cmn x, y
+  if ((LHSDef->getFlag(MachineInstr::NoSWrap) && !CmpInst::isUnsigned(P)) ||
+      (P == CmpInst::ICMP_EQ || P == CmpInst::ICMP_NE ||
+       P == CmpInst::ICMP_SLT || P == CmpInst::ICMP_SGE)) {
+
+    if (LHSDef->getOpcode() == TargetOpcode::G_SUB ||
+        LHSDef->getOpcode() == TargetOpcode::G_ADD) {
+      // Make sure that the RHS is 0.
+      auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS.getReg(), MRI);
+      if (!ValAndVReg || ValAndVReg->Value != 0)
+        return nullptr;
+
+      if (LHSDef->getOpcode() == TargetOpcode::G_SUB) {
+        auto Dst = MRI.cloneVirtualRegister(LHS.getReg());
+        return emitSUBS(Dst, LHSDef->getOperand(1), LHSDef->getOperand(2),
+                        MIRBuilder);
+      } else {
+        return emitCMN(LHSDef->getOperand(1), LHSDef->getOperand(2),
+                       MIRBuilder);
+      }
+    }
+  }
+
   // Given this:
   //
   // z = G_AND x, y
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 12d2d678ff63a..072392bc5315b 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -4808,6 +4808,19 @@ SDValue ARMTargetLowering::getARMCmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
     CompareType = ARMISD::CMPZ;
     break;
   }
+
+  // For MI/PL/EQ/NE (or LT/GE/LE/GT on an nsw LHS), compare a SUB's operands directly.
+  if (CondCode == ARMCC::MI || CondCode == ARMCC::PL || CondCode == ARMCC::EQ ||
+      CondCode == ARMCC::NE ||
+      (LHS.getFlags().hasNoSignedWrap() &&
+       (CondCode == ARMCC::LT || CondCode == ARMCC::GE ||
+        CondCode == ARMCC::LE || CondCode == ARMCC::GT))) {
+    if (LHS.getOpcode() == ISD::SUB) {
+      ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
+      return DAG.getNode(CompareType, dl, FlagsVT, LHS.getOperand(0),
+                         LHS.getOperand(1));
+    }
+  }
   ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
   return DAG.getNode(CompareType, dl, FlagsVT, LHS, RHS);
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/155311


More information about the llvm-commits mailing list