[llvm] [AArch64] Allow folding between CMN and ADDS and other flag-setting nodes if the operands are commutative (PR #160170)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 23 05:44:24 PDT 2025
https://github.com/AZero13 updated https://github.com/llvm/llvm-project/pull/160170
From 3ae88e5cb63f5ec024af78f669ce58a5e3e3c42d Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Mon, 22 Sep 2025 14:19:49 -0400
Subject: [PATCH 1/2] Pre-commit test (NFC)
---
llvm/test/CodeGen/AArch64/cmp-to-cmn.ll | 16 ++++++++++++++++
1 file changed, 16 insertions(+)
diff --git a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
index b3ce9d2369104..b5afb90ed5fbf 100644
--- a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
+++ b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
@@ -843,3 +843,19 @@ define i1 @cmn_nsw_neg_64(i64 %a, i64 %b) {
%cmp = icmp sgt i64 %a, %sub
ret i1 %cmp
}
+
+define i1 @cmn_and_adds(i32 %num, i32 %num2, ptr %use) {
+; CHECK-LABEL: cmn_and_adds:
+; CHECK: // %bb.0:
+; CHECK-NEXT: cmn w0, w1
+; CHECK-NEXT: add w9, w1, w0
+; CHECK-NEXT: cset w8, lt
+; CHECK-NEXT: str w9, [x2]
+; CHECK-NEXT: mov w0, w8
+; CHECK-NEXT: ret
+ %add = add nsw i32 %num2, %num
+ store i32 %add, ptr %use, align 4
+ %sub = sub nsw i32 0, %num2
+ %cmp = icmp slt i32 %num, %sub
+ ret i1 %cmp
+}
From 7ee4433b793829b3692632cc232576bba0ac011c Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Mon, 22 Sep 2025 14:29:44 -0400
Subject: [PATCH 2/2] [AArch64] Allow folding between CMN and ADDS and other
flag setting nodes if the operands are commutative
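Overriding isCommutativeBinOp for ANDS, ADDS, ADC, and ADCS (and marking ADC/ADCS
as SDNPCommutative) lets performFlagSettingCombine also fold the generic node when
its operands appear in the opposite order from the flag-setting node. The
cmn_and_adds test added in the pre-commit patch is the motivating case:

  define i1 @cmn_and_adds(i32 %num, i32 %num2, ptr %use) {
    %add = add nsw i32 %num2, %num      ; add with operands swapped
    store i32 %add, ptr %use, align 4
    %sub = sub nsw i32 0, %num2
    %cmp = icmp slt i32 %num, %sub      ; selected as "cmn w0, w1"
    ret i1 %cmp
  }

Before this change, SelectionDAG emitted both "cmn w0, w1" and "add w9, w1, w0"
for this function; with the reversed-operand check the pair now folds into a
single "adds w8, w0, w1".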
---
.../Target/AArch64/AArch64ISelLowering.cpp | 38 +++++++++++++++++--
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 11 ++++++
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 4 +-
llvm/test/CodeGen/AArch64/adds_cmn.ll | 6 +--
llvm/test/CodeGen/AArch64/cmp-to-cmn.ll | 23 +++++++----
llvm/test/CodeGen/AArch64/sat-add.ll | 6 +--
6 files changed, 67 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cd7f0e719ad0c..250074f462c3e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17107,6 +17107,28 @@ bool AArch64TargetLowering::shouldRemoveRedundantExtend(SDValue Extend) const {
return true;
}
+bool AArch64TargetLowering::isBinOp(unsigned Opcode) const {
+ switch (Opcode) {
+ // TODO: Add more?
+ case AArch64ISD::SUBS:
+ case AArch64ISD::SBC:
+ case AArch64ISD::SBCS:
+ return true;
+ }
+ return TargetLoweringBase::isBinOp(Opcode);
+}
+
+bool AArch64TargetLowering::isCommutativeBinOp(unsigned Opcode) const {
+ switch (Opcode) {
+ case AArch64ISD::ANDS:
+ case AArch64ISD::ADDS:
+ case AArch64ISD::ADC:
+ case AArch64ISD::ADCS:
+ return true;
+ }
+ return TargetLoweringBase::isCommutativeBinOp(Opcode);
+}
+
// Truncations from 64-bit GPR to 32-bit GPR is free.
bool AArch64TargetLowering::isTruncateFree(Type *Ty1, Type *Ty2) const {
if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
@@ -25994,9 +26016,9 @@ static SDValue performSETCCCombine(SDNode *N,
// Replace a flag-setting operator (eg ANDS) with the generic version
// (eg AND) if the flag is unused.
-static SDValue performFlagSettingCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- unsigned GenericOpcode) {
+SDValue AArch64TargetLowering::performFlagSettingCombine(
+ SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ unsigned GenericOpcode) const {
SDLoc DL(N);
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
@@ -26013,6 +26035,16 @@ static SDValue performFlagSettingCombine(SDNode *N,
GenericOpcode, DCI.DAG.getVTList(VT), {LHS, RHS}))
DCI.CombineTo(Generic, SDValue(N, 0));
+ // An opcode may have swappable value operands even if it is not marked
+ // commutative elsewhere: ADCS, for example, is not treated as commutative
+ // by the rest of the codebase because it consumes an incoming carry flag,
+ // but the order of its two value operands does not matter.
+ if (isCommutativeBinOp(GenericOpcode)) {
+ if (SDNode *Generic = DCI.DAG.getNodeIfExists(
+ GenericOpcode, DCI.DAG.getVTList(VT), {RHS, LHS}))
+ DCI.CombineTo(Generic, SDValue(N, 0));
+ }
+
return SDValue();
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index ff073d3eafb1f..c955ae11697d6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -250,6 +250,11 @@ class AArch64TargetLowering : public TargetLowering {
bool isLegalAddScalableImmediate(int64_t) const override;
bool isLegalICmpImmediate(int64_t) const override;
+ /// Add AArch64-specific opcodes to the default list.
+ bool isBinOp(unsigned Opcode) const override;
+
+ bool isCommutativeBinOp(unsigned Opcode) const override;
+
bool isMulAddWithConstProfitable(SDValue AddNode,
SDValue ConstNode) const override;
@@ -913,6 +918,12 @@ class AArch64TargetLowering : public TargetLowering {
bool hasMultipleConditionRegisters(EVT VT) const override {
return VT.isScalableVector();
}
+
+ // Replace a flag-setting operator (eg ANDS) with the generic version
+ // (eg AND) if the flag is unused.
+ SDValue performFlagSettingCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ unsigned GenericOpcode) const;
};
namespace AArch64 {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 6cea453f271be..e2cb3a2262bcc 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -842,7 +842,7 @@ def AArch64csinc : SDNode<"AArch64ISD::CSINC", SDT_AArch64CSel>;
// Return with a glue operand. Operand 0 is the chain operand.
def AArch64retglue : SDNode<"AArch64ISD::RET_GLUE", SDTNone,
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
-def AArch64adc : SDNode<"AArch64ISD::ADC", SDTBinaryArithWithFlagsIn >;
+def AArch64adc : SDNode<"AArch64ISD::ADC", SDTBinaryArithWithFlagsIn, [SDNPCommutative]>;
def AArch64sbc : SDNode<"AArch64ISD::SBC", SDTBinaryArithWithFlagsIn>;
// Arithmetic instructions which write flags.
@@ -851,7 +851,7 @@ def AArch64add_flag : SDNode<"AArch64ISD::ADDS", SDTBinaryArithWithFlagsOut,
def AArch64sub_flag : SDNode<"AArch64ISD::SUBS", SDTBinaryArithWithFlagsOut>;
def AArch64and_flag : SDNode<"AArch64ISD::ANDS", SDTBinaryArithWithFlagsOut,
[SDNPCommutative]>;
-def AArch64adc_flag : SDNode<"AArch64ISD::ADCS", SDTBinaryArithWithFlagsInOut>;
+def AArch64adc_flag : SDNode<"AArch64ISD::ADCS", SDTBinaryArithWithFlagsInOut, [SDNPCommutative]>;
def AArch64sbc_flag : SDNode<"AArch64ISD::SBCS", SDTBinaryArithWithFlagsInOut>;
// Conditional compares. Operands: left,right,falsecc,cc,flags
diff --git a/llvm/test/CodeGen/AArch64/adds_cmn.ll b/llvm/test/CodeGen/AArch64/adds_cmn.ll
index aa070b7886ba5..9b456a5419d61 100644
--- a/llvm/test/CodeGen/AArch64/adds_cmn.ll
+++ b/llvm/test/CodeGen/AArch64/adds_cmn.ll
@@ -22,10 +22,8 @@ entry:
define { i32, i32 } @adds_cmn_c(i32 noundef %x, i32 noundef %y) {
; CHECK-LABEL: adds_cmn_c:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: cmn w0, w1
-; CHECK-NEXT: add w1, w1, w0
-; CHECK-NEXT: cset w8, lo
-; CHECK-NEXT: mov w0, w8
+; CHECK-NEXT: adds w1, w0, w1
+; CHECK-NEXT: cset w0, lo
; CHECK-NEXT: ret
entry:
%0 = tail call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %x, i32 %y)
diff --git a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
index b5afb90ed5fbf..44a38d7947d66 100644
--- a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
+++ b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
@@ -845,14 +845,21 @@ define i1 @cmn_nsw_neg_64(i64 %a, i64 %b) {
}
define i1 @cmn_and_adds(i32 %num, i32 %num2, ptr %use) {
-; CHECK-LABEL: cmn_and_adds:
-; CHECK: // %bb.0:
-; CHECK-NEXT: cmn w0, w1
-; CHECK-NEXT: add w9, w1, w0
-; CHECK-NEXT: cset w8, lt
-; CHECK-NEXT: str w9, [x2]
-; CHECK-NEXT: mov w0, w8
-; CHECK-NEXT: ret
+; CHECK-SD-LABEL: cmn_and_adds:
+; CHECK-SD: // %bb.0:
+; CHECK-SD-NEXT: adds w8, w0, w1
+; CHECK-SD-NEXT: cset w0, lt
+; CHECK-SD-NEXT: str w8, [x2]
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: cmn_and_adds:
+; CHECK-GI: // %bb.0:
+; CHECK-GI-NEXT: cmn w0, w1
+; CHECK-GI-NEXT: add w9, w1, w0
+; CHECK-GI-NEXT: cset w8, lt
+; CHECK-GI-NEXT: str w9, [x2]
+; CHECK-GI-NEXT: mov w0, w8
+; CHECK-GI-NEXT: ret
%add = add nsw i32 %num2, %num
store i32 %add, ptr %use, align 4
%sub = sub nsw i32 0, %num2
diff --git a/llvm/test/CodeGen/AArch64/sat-add.ll b/llvm/test/CodeGen/AArch64/sat-add.ll
index ecd48d6b7c65b..149b4c4fd26c9 100644
--- a/llvm/test/CodeGen/AArch64/sat-add.ll
+++ b/llvm/test/CodeGen/AArch64/sat-add.ll
@@ -290,8 +290,7 @@ define i32 @unsigned_sat_variable_i32_using_cmp_sum(i32 %x, i32 %y) {
define i32 @unsigned_sat_variable_i32_using_cmp_notval(i32 %x, i32 %y) {
; CHECK-LABEL: unsigned_sat_variable_i32_using_cmp_notval:
; CHECK: // %bb.0:
-; CHECK-NEXT: add w8, w0, w1
-; CHECK-NEXT: cmn w1, w0
+; CHECK-NEXT: adds w8, w1, w0
; CHECK-NEXT: csinv w0, w8, wzr, lo
; CHECK-NEXT: ret
%noty = xor i32 %y, -1
@@ -331,8 +330,7 @@ define i64 @unsigned_sat_variable_i64_using_cmp_sum(i64 %x, i64 %y) {
define i64 @unsigned_sat_variable_i64_using_cmp_notval(i64 %x, i64 %y) {
; CHECK-LABEL: unsigned_sat_variable_i64_using_cmp_notval:
; CHECK: // %bb.0:
-; CHECK-NEXT: add x8, x0, x1
-; CHECK-NEXT: cmn x1, x0
+; CHECK-NEXT: adds x8, x1, x0
; CHECK-NEXT: csinv x0, x8, xzr, lo
; CHECK-NEXT: ret
%noty = xor i64 %y, -1