[llvm] [AArch64] Allow folding between CMN and ADDS by marking ADDS and other flag-setting nodes as commutative (PR #160170)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 14 12:19:18 PDT 2025


https://github.com/AZero13 updated https://github.com/llvm/llvm-project/pull/160170

From 9ee7be0b41adffe22eb7675cc929a52d097d5ee8 Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Mon, 22 Sep 2025 14:19:49 -0400
Subject: [PATCH] [AArch64] Explicitly mark ADDS, ANDS, SUBS, etc as binops

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 22 ++++++++++++++++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  5 ++++
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |  4 ++--
 llvm/test/CodeGen/AArch64/cmp-to-cmn.ll       | 23 +++++++++++++++++++
 4 files changed, 52 insertions(+), 2 deletions(-)
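
For context, an illustration that is not part of the patch: the new cmn_and_adds
test added at the end of this diff corresponds roughly to the C++ below. The sum
is stored, and the compare against the negated operand can set flags with the
same ADDS once the node is known to be commutative, so SelectionDAG no longer
needs a separate CMN.

    // Illustrative source-level shape of the cmn_and_adds test (assumption:
    // names and types mirror the test; nothing here is taken from the patch).
    #include <cstdint>

    bool cmn_and_adds(int32_t num, int32_t num2, int32_t *use) {
      *use = num2 + num;   // add nsw i32 %num2, %num
      return num < -num2;  // icmp slt i32 %num, (sub nsw 0, %num2)
    }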

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9926a4d7baec6..df79e55462c05 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17194,6 +17194,28 @@ bool AArch64TargetLowering::shouldRemoveRedundantExtend(SDValue Extend) const {
   return true;
 }
 
+bool AArch64TargetLowering::isBinOp(unsigned Opcode) const {
+  switch (Opcode) {
+  // TODO: Add more?
+  case AArch64ISD::SUBS:
+  case AArch64ISD::SBC:
+  case AArch64ISD::SBCS:
+    return true;
+  }
+  return TargetLoweringBase::isBinOp(Opcode);
+}
+
+bool AArch64TargetLowering::isCommutativeBinOp(unsigned Opcode) const {
+  switch (Opcode) {
+  case AArch64ISD::ANDS:
+  case AArch64ISD::ADDS:
+  case AArch64ISD::ADC:
+  case AArch64ISD::ADCS:
+    return true;
+  }
+  return TargetLoweringBase::isCommutativeBinOp(Opcode);
+}
+
 // Truncations from 64-bit GPR to 32-bit GPR is free.
 bool AArch64TargetLowering::isTruncateFree(Type *Ty1, Type *Ty2) const {
   if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
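
A minimal sketch of why the new hook matters (an assumption about the general
TargetLowering query pattern, not code from this patch): once AArch64ISD::ADDS
is reported as commutative, a combine that looks for an existing flag-setting
node to reuse can accept its operands in either order.

    // Minimal sketch, assuming the usual TargetLowering query pattern; the
    // actual DAGCombiner folds differ. A candidate node matches in either
    // operand order when the target reports the opcode as commutative.
    #include "llvm/CodeGen/SelectionDAGNodes.h"
    #include "llvm/CodeGen/TargetLowering.h"

    using namespace llvm;

    static bool matchesEitherOrder(const TargetLowering &TLI, SDNode *Candidate,
                                   unsigned Opcode, SDValue X, SDValue Y) {
      if (Candidate->getOpcode() != Opcode)
        return false;
      if (Candidate->getOperand(0) == X && Candidate->getOperand(1) == Y)
        return true;
      return TLI.isCommutativeBinOp(Opcode) &&
             Candidate->getOperand(0) == Y && Candidate->getOperand(1) == X;
    }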
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 00956fdc8e48e..2ceb8f90bab1a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -250,6 +250,11 @@ class AArch64TargetLowering : public TargetLowering {
   bool isLegalAddScalableImmediate(int64_t) const override;
   bool isLegalICmpImmediate(int64_t) const override;
 
+  /// Add AArch64-specific opcodes to the default list.
+  bool isBinOp(unsigned Opcode) const override;
+
+  bool isCommutativeBinOp(unsigned Opcode) const override;
+
   bool isMulAddWithConstProfitable(SDValue AddNode,
                                    SDValue ConstNode) const override;
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index f788c7510f80c..65e6752aaffae 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -842,7 +842,7 @@ def AArch64csinc         : SDNode<"AArch64ISD::CSINC", SDT_AArch64CSel>;
 // Return with a glue operand. Operand 0 is the chain operand.
 def AArch64retglue       : SDNode<"AArch64ISD::RET_GLUE", SDTNone,
                                 [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
-def AArch64adc       : SDNode<"AArch64ISD::ADC",  SDTBinaryArithWithFlagsIn >;
+def AArch64adc       : SDNode<"AArch64ISD::ADC",  SDTBinaryArithWithFlagsIn, [SDNPCommutative]>;
 def AArch64sbc       : SDNode<"AArch64ISD::SBC",  SDTBinaryArithWithFlagsIn>;
 
 // Arithmetic instructions which write flags.
@@ -851,7 +851,7 @@ def AArch64add_flag  : SDNode<"AArch64ISD::ADDS",  SDTBinaryArithWithFlagsOut,
 def AArch64sub_flag  : SDNode<"AArch64ISD::SUBS",  SDTBinaryArithWithFlagsOut>;
 def AArch64and_flag  : SDNode<"AArch64ISD::ANDS",  SDTBinaryArithWithFlagsOut,
                             [SDNPCommutative]>;
-def AArch64adc_flag  : SDNode<"AArch64ISD::ADCS",  SDTBinaryArithWithFlagsInOut>;
+def AArch64adc_flag  : SDNode<"AArch64ISD::ADCS",  SDTBinaryArithWithFlagsInOut, [SDNPCommutative]>;
 def AArch64sbc_flag  : SDNode<"AArch64ISD::SBCS",  SDTBinaryArithWithFlagsInOut>;
 
 // Conditional compares. Operands: left,right,falsecc,cc,flags
diff --git a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
index b3ce9d2369104..44a38d7947d66 100644
--- a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
+++ b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
@@ -843,3 +843,26 @@ define i1 @cmn_nsw_neg_64(i64 %a, i64 %b) {
   %cmp = icmp sgt i64 %a, %sub
   ret i1 %cmp
 }
+
+define i1 @cmn_and_adds(i32 %num, i32 %num2, ptr %use)  {
+; CHECK-SD-LABEL: cmn_and_adds:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    adds w8, w0, w1
+; CHECK-SD-NEXT:    cset w0, lt
+; CHECK-SD-NEXT:    str w8, [x2]
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: cmn_and_adds:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    cmn w0, w1
+; CHECK-GI-NEXT:    add w9, w1, w0
+; CHECK-GI-NEXT:    cset w8, lt
+; CHECK-GI-NEXT:    str w9, [x2]
+; CHECK-GI-NEXT:    mov w0, w8
+; CHECK-GI-NEXT:    ret
+  %add = add nsw i32 %num2, %num
+  store i32 %add, ptr %use, align 4
+  %sub = sub nsw i32 0, %num2
+  %cmp = icmp slt i32 %num, %sub
+  ret i1 %cmp
+}


