[llvm] [AArch64] Allow folding between CMN and ADDS and other flag-setting nodes if the operands are commutative (PR #160170)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 22 13:06:42 PDT 2025


https://github.com/AZero13 updated https://github.com/llvm/llvm-project/pull/160170

>From 3ae88e5cb63f5ec024af78f669ce58a5e3e3c42d Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Mon, 22 Sep 2025 14:19:49 -0400
Subject: [PATCH 1/2] Pre-commit test (NFC)

---
 llvm/test/CodeGen/AArch64/cmp-to-cmn.ll | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
index b3ce9d2369104..b5afb90ed5fbf 100644
--- a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
+++ b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
@@ -843,3 +843,19 @@ define i1 @cmn_nsw_neg_64(i64 %a, i64 %b) {
   %cmp = icmp sgt i64 %a, %sub
   ret i1 %cmp
 }
+
+define i1 @cmn_and_adds(i32 %num, i32 %num2, ptr %use)  {
+; CHECK-LABEL: cmn_and_adds:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmn w0, w1
+; CHECK-NEXT:    add w9, w1, w0
+; CHECK-NEXT:    cset w8, lt
+; CHECK-NEXT:    str w9, [x2]
+; CHECK-NEXT:    mov w0, w8
+; CHECK-NEXT:    ret
+  %add = add nsw i32 %num2, %num
+  store i32 %add, ptr %use, align 4
+  %sub = sub nsw i32 0, %num2
+  %cmp = icmp slt i32 %num, %sub
+  ret i1 %cmp
+}

>From d27785b8ee203ea3f1ce0cb69bef04aa184ade30 Mon Sep 17 00:00:00 2001
From: AZero13 <gfunni234 at gmail.com>
Date: Mon, 22 Sep 2025 14:29:44 -0400
Subject: [PATCH 2/2] [AArch64] Allow folding between CMN and ADDS and other
 flag-setting nodes if the operands are commutative

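performFlagSettingCombine only reuses a flag-setting node for an existing
generic node when the operands appear in the same order. For commutative
operations such as ADDS this misses cases where the generic ADD was built
with the operands swapped, so both an ADD and a CMN end up in the output.
Extend the combine to also look for the generic node with swapped operands
when the opcode is commutative, and mark ADC/ADCS as commutative so the
existing patterns can take advantage of it.

As a small illustration, taken from the cmn_and_adds test added in the
pre-commit patch:

  %add = add nsw i32 %num2, %num
  store i32 %add, ptr %use, align 4
  %sub = sub nsw i32 0, %num2
  %cmp = icmp slt i32 %num, %sub

previously selected a separate "cmn w0, w1" and "add w9, w1, w0"; with this
change the pair folds into a single "adds w8, w0, w1".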
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  7 ++++
 .../Target/AArch64/AArch64ISelLowering.cpp    | 35 +++++++++++++++++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  5 +++
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |  4 +--
 llvm/test/CodeGen/AArch64/aarch64-smull.ll    | 16 ++++-----
 llvm/test/CodeGen/AArch64/adds_cmn.ll         |  6 ++--
 llvm/test/CodeGen/AArch64/arm64-vmul.ll       | 34 ++++++++++++------
 llvm/test/CodeGen/AArch64/cmp-to-cmn.ll       | 23 +++++++-----
 llvm/test/CodeGen/AArch64/sat-add.ll          |  6 ++--
 9 files changed, 99 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0c773e7dcb5de..cc86d4d31e74f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -25841,6 +25841,13 @@ static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index,
   if (WideNumElts % NarrowingRatio != 0)
     return SDValue();
 
+  // Bail out if the binop has different element types between operands and
+  // result. This prevents incorrect transforms for ops like PMULL, where the
+  // result element type differs from the operand element types.
+  if (BinOp.getOperand(0).getValueType().getScalarType() !=
+      WideBVT.getScalarType())
+    return SDValue();
+
   // Bail out if the target does not support a narrower version of the binop.
   EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
                                    WideNumElts / NarrowingRatio);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cd7f0e719ad0c..9a148e574ca5c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17107,6 +17107,31 @@ bool AArch64TargetLowering::shouldRemoveRedundantExtend(SDValue Extend) const {
   return true;
 }
 
+bool AArch64TargetLowering::isBinOp(unsigned Opcode) const {
+  switch (Opcode) {
+  // TODO: Add more?
+  case AArch64ISD::SUBS:
+  case AArch64ISD::SBC:
+  case AArch64ISD::SBCS:
+    return true;
+  }
+  return TargetLoweringBase::isBinOp(Opcode);
+}
+
+bool AArch64TargetLowering::isCommutativeBinOp(unsigned Opcode) const {
+  switch (Opcode) {
+  case AArch64ISD::ANDS:
+  case AArch64ISD::ADDS:
+  case AArch64ISD::ADC:
+  case AArch64ISD::ADCS:
+  case AArch64ISD::PMULL:
+  case AArch64ISD::SMULL:
+  case AArch64ISD::UMULL:
+    return true;
+  }
+  return TargetLoweringBase::isCommutativeBinOp(Opcode);
+}
+
 // Truncations from 64-bit GPR to 32-bit GPR is free.
 bool AArch64TargetLowering::isTruncateFree(Type *Ty1, Type *Ty2) const {
   if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
@@ -26013,6 +26038,16 @@ static SDValue performFlagSettingCombine(SDNode *N,
           GenericOpcode, DCI.DAG.getVTList(VT), {LHS, RHS}))
     DCI.CombineTo(Generic, SDValue(N, 0));
 
+  // An opcode that is not marked commutative may still have interchangeable
+  // value operands. ADCS, for example, is not treated as commutative by the
+  // rest of the codebase because of its flags operand, but the order of its
+  // two value operands does not matter. So also try the swapped operand order.
+  if (isCommutativeBinOp(GenericOpcode)) {
+    if (SDNode *Generic = DCI.DAG.getNodeIfExists(
+            GenericOpcode, DCI.DAG.getVTList(VT), {RHS, LHS}))
+      DCI.CombineTo(Generic, SDValue(N, 0));
+  }
+
   return SDValue();
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index ff073d3eafb1f..c1ec703acca33 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -250,6 +250,11 @@ class AArch64TargetLowering : public TargetLowering {
   bool isLegalAddScalableImmediate(int64_t) const override;
   bool isLegalICmpImmediate(int64_t) const override;
 
+  /// Add AArch64-specific opcodes to the default list.
+  bool isBinOp(unsigned Opcode) const override;
+
+  bool isCommutativeBinOp(unsigned Opcode) const override;
+
   bool isMulAddWithConstProfitable(SDValue AddNode,
                                    SDValue ConstNode) const override;
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 6cea453f271be..e2cb3a2262bcc 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -842,7 +842,7 @@ def AArch64csinc         : SDNode<"AArch64ISD::CSINC", SDT_AArch64CSel>;
 // Return with a glue operand. Operand 0 is the chain operand.
 def AArch64retglue       : SDNode<"AArch64ISD::RET_GLUE", SDTNone,
                                 [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
-def AArch64adc       : SDNode<"AArch64ISD::ADC",  SDTBinaryArithWithFlagsIn >;
+def AArch64adc       : SDNode<"AArch64ISD::ADC",  SDTBinaryArithWithFlagsIn, [SDNPCommutative]>;
 def AArch64sbc       : SDNode<"AArch64ISD::SBC",  SDTBinaryArithWithFlagsIn>;
 
 // Arithmetic instructions which write flags.
@@ -851,7 +851,7 @@ def AArch64add_flag  : SDNode<"AArch64ISD::ADDS",  SDTBinaryArithWithFlagsOut,
 def AArch64sub_flag  : SDNode<"AArch64ISD::SUBS",  SDTBinaryArithWithFlagsOut>;
 def AArch64and_flag  : SDNode<"AArch64ISD::ANDS",  SDTBinaryArithWithFlagsOut,
                             [SDNPCommutative]>;
-def AArch64adc_flag  : SDNode<"AArch64ISD::ADCS",  SDTBinaryArithWithFlagsInOut>;
+def AArch64adc_flag  : SDNode<"AArch64ISD::ADCS",  SDTBinaryArithWithFlagsInOut, [SDNPCommutative]>;
 def AArch64sbc_flag  : SDNode<"AArch64ISD::SBCS",  SDTBinaryArithWithFlagsInOut>;
 
 // Conditional compares. Operands: left,right,falsecc,cc,flags
diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
index 6e5c666bdbc75..72a4377034f21 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -2459,14 +2459,14 @@ define <8 x i16> @sdistribute_const1_v8i8(<8 x i8> %src1, <8 x i8> %mul) {
 ; CHECK-NEON:       // %bb.0: // %entry
 ; CHECK-NEON-NEXT:    movi v2.8b, #10
 ; CHECK-NEON-NEXT:    smull v0.8h, v0.8b, v1.8b
-; CHECK-NEON-NEXT:    smlal v0.8h, v2.8b, v1.8b
+; CHECK-NEON-NEXT:    smlal v0.8h, v1.8b, v2.8b
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: sdistribute_const1_v8i8:
 ; CHECK-SVE:       // %bb.0: // %entry
 ; CHECK-SVE-NEXT:    movi v2.8b, #10
 ; CHECK-SVE-NEXT:    smull v0.8h, v0.8b, v1.8b
-; CHECK-SVE-NEXT:    smlal v0.8h, v2.8b, v1.8b
+; CHECK-SVE-NEXT:    smlal v0.8h, v1.8b, v2.8b
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: sdistribute_const1_v8i8:
@@ -2546,14 +2546,14 @@ define <8 x i16> @udistribute_const1_v8i8(<8 x i8> %src1, <8 x i8> %mul) {
 ; CHECK-NEON:       // %bb.0: // %entry
 ; CHECK-NEON-NEXT:    movi v2.8b, #10
 ; CHECK-NEON-NEXT:    umull v0.8h, v0.8b, v1.8b
-; CHECK-NEON-NEXT:    umlal v0.8h, v2.8b, v1.8b
+; CHECK-NEON-NEXT:    umlal v0.8h, v1.8b, v2.8b
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: udistribute_const1_v8i8:
 ; CHECK-SVE:       // %bb.0: // %entry
 ; CHECK-SVE-NEXT:    movi v2.8b, #10
 ; CHECK-SVE-NEXT:    umull v0.8h, v0.8b, v1.8b
-; CHECK-SVE-NEXT:    umlal v0.8h, v2.8b, v1.8b
+; CHECK-SVE-NEXT:    umlal v0.8h, v1.8b, v2.8b
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: udistribute_const1_v8i8:
@@ -2779,14 +2779,14 @@ define <2 x i64> @sdistribute_const1_v2i32(<2 x i32> %src1, <2 x i32> %mul) {
 ; CHECK-NEON:       // %bb.0: // %entry
 ; CHECK-NEON-NEXT:    movi v2.2s, #10
 ; CHECK-NEON-NEXT:    smull v0.2d, v0.2s, v1.2s
-; CHECK-NEON-NEXT:    smlal v0.2d, v2.2s, v1.2s
+; CHECK-NEON-NEXT:    smlal v0.2d, v1.2s, v2.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: sdistribute_const1_v2i32:
 ; CHECK-SVE:       // %bb.0: // %entry
 ; CHECK-SVE-NEXT:    movi v2.2s, #10
 ; CHECK-SVE-NEXT:    smull v0.2d, v0.2s, v1.2s
-; CHECK-SVE-NEXT:    smlal v0.2d, v2.2s, v1.2s
+; CHECK-SVE-NEXT:    smlal v0.2d, v1.2s, v2.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: sdistribute_const1_v2i32:
@@ -2889,14 +2889,14 @@ define <2 x i64> @udistribute_const1_v2i32(<2 x i32> %src1, <2 x i32> %mul) {
 ; CHECK-NEON:       // %bb.0: // %entry
 ; CHECK-NEON-NEXT:    movi v2.2s, #10
 ; CHECK-NEON-NEXT:    umull v0.2d, v0.2s, v1.2s
-; CHECK-NEON-NEXT:    umlal v0.2d, v2.2s, v1.2s
+; CHECK-NEON-NEXT:    umlal v0.2d, v1.2s, v2.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: udistribute_const1_v2i32:
 ; CHECK-SVE:       // %bb.0: // %entry
 ; CHECK-SVE-NEXT:    movi v2.2s, #10
 ; CHECK-SVE-NEXT:    umull v0.2d, v0.2s, v1.2s
-; CHECK-SVE-NEXT:    umlal v0.2d, v2.2s, v1.2s
+; CHECK-SVE-NEXT:    umlal v0.2d, v1.2s, v2.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: udistribute_const1_v2i32:
diff --git a/llvm/test/CodeGen/AArch64/adds_cmn.ll b/llvm/test/CodeGen/AArch64/adds_cmn.ll
index aa070b7886ba5..9b456a5419d61 100644
--- a/llvm/test/CodeGen/AArch64/adds_cmn.ll
+++ b/llvm/test/CodeGen/AArch64/adds_cmn.ll
@@ -22,10 +22,8 @@ entry:
 define { i32, i32 } @adds_cmn_c(i32 noundef %x, i32 noundef %y) {
 ; CHECK-LABEL: adds_cmn_c:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    cmn w0, w1
-; CHECK-NEXT:    add w1, w1, w0
-; CHECK-NEXT:    cset w8, lo
-; CHECK-NEXT:    mov w0, w8
+; CHECK-NEXT:    adds w1, w0, w1
+; CHECK-NEXT:    cset w0, lo
 ; CHECK-NEXT:    ret
 entry:
   %0 = tail call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %x, i32 %y)
diff --git a/llvm/test/CodeGen/AArch64/arm64-vmul.ll b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
index e6df9f2fb2c56..bd63548741c3b 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vmul.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
@@ -81,11 +81,17 @@ define <2 x i64> @smull2d(ptr %A, ptr %B) nounwind {
 }
 
 define void @commutable_smull(<2 x i32> %A, <2 x i32> %B, ptr %C) {
-; CHECK-LABEL: commutable_smull:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull v0.2d, v0.2s, v1.2s
-; CHECK-NEXT:    stp q0, q0, [x0]
-; CHECK-NEXT:    ret
+; CHECK-SD-LABEL: commutable_smull:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    smull v0.2d, v1.2s, v0.2s
+; CHECK-SD-NEXT:    stp q0, q0, [x0]
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: commutable_smull:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    smull v0.2d, v0.2s, v1.2s
+; CHECK-GI-NEXT:    stp q0, q0, [x0]
+; CHECK-GI-NEXT:    ret
   %1 = call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %A, <2 x i32> %B)
   %2 = call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %B, <2 x i32> %A)
   store <2 x i64> %1, ptr %C
@@ -138,11 +144,17 @@ define <2 x i64> @umull2d(ptr %A, ptr %B) nounwind {
 }
 
 define void @commutable_umull(<2 x i32> %A, <2 x i32> %B, ptr %C) {
-; CHECK-LABEL: commutable_umull:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull v0.2d, v0.2s, v1.2s
-; CHECK-NEXT:    stp q0, q0, [x0]
-; CHECK-NEXT:    ret
+; CHECK-SD-LABEL: commutable_umull:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    umull v0.2d, v1.2s, v0.2s
+; CHECK-SD-NEXT:    stp q0, q0, [x0]
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: commutable_umull:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    umull v0.2d, v0.2s, v1.2s
+; CHECK-GI-NEXT:    stp q0, q0, [x0]
+; CHECK-GI-NEXT:    ret
   %1 = call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %A, <2 x i32> %B)
   %2 = call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %B, <2 x i32> %A)
   store <2 x i64> %1, ptr %C
@@ -245,7 +257,7 @@ define <8 x i16> @pmull8h(ptr %A, ptr %B) nounwind {
 define void @commutable_pmull8h(<8 x i8> %A, <8 x i8> %B, ptr %C) {
 ; CHECK-LABEL: commutable_pmull8h:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    pmull v0.8h, v0.8b, v1.8b
+; CHECK-NEXT:    pmull v0.8h, v1.8b, v0.8b
 ; CHECK-NEXT:    stp q0, q0, [x0]
 ; CHECK-NEXT:    ret
   %1 = call <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8> %A, <8 x i8> %B)
diff --git a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
index b5afb90ed5fbf..44a38d7947d66 100644
--- a/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
+++ b/llvm/test/CodeGen/AArch64/cmp-to-cmn.ll
@@ -845,14 +845,21 @@ define i1 @cmn_nsw_neg_64(i64 %a, i64 %b) {
 }
 
 define i1 @cmn_and_adds(i32 %num, i32 %num2, ptr %use)  {
-; CHECK-LABEL: cmn_and_adds:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    cmn w0, w1
-; CHECK-NEXT:    add w9, w1, w0
-; CHECK-NEXT:    cset w8, lt
-; CHECK-NEXT:    str w9, [x2]
-; CHECK-NEXT:    mov w0, w8
-; CHECK-NEXT:    ret
+; CHECK-SD-LABEL: cmn_and_adds:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    adds w8, w0, w1
+; CHECK-SD-NEXT:    cset w0, lt
+; CHECK-SD-NEXT:    str w8, [x2]
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: cmn_and_adds:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    cmn w0, w1
+; CHECK-GI-NEXT:    add w9, w1, w0
+; CHECK-GI-NEXT:    cset w8, lt
+; CHECK-GI-NEXT:    str w9, [x2]
+; CHECK-GI-NEXT:    mov w0, w8
+; CHECK-GI-NEXT:    ret
   %add = add nsw i32 %num2, %num
   store i32 %add, ptr %use, align 4
   %sub = sub nsw i32 0, %num2
diff --git a/llvm/test/CodeGen/AArch64/sat-add.ll b/llvm/test/CodeGen/AArch64/sat-add.ll
index ecd48d6b7c65b..149b4c4fd26c9 100644
--- a/llvm/test/CodeGen/AArch64/sat-add.ll
+++ b/llvm/test/CodeGen/AArch64/sat-add.ll
@@ -290,8 +290,7 @@ define i32 @unsigned_sat_variable_i32_using_cmp_sum(i32 %x, i32 %y) {
 define i32 @unsigned_sat_variable_i32_using_cmp_notval(i32 %x, i32 %y) {
 ; CHECK-LABEL: unsigned_sat_variable_i32_using_cmp_notval:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    add w8, w0, w1
-; CHECK-NEXT:    cmn w1, w0
+; CHECK-NEXT:    adds w8, w1, w0
 ; CHECK-NEXT:    csinv w0, w8, wzr, lo
 ; CHECK-NEXT:    ret
   %noty = xor i32 %y, -1
@@ -331,8 +330,7 @@ define i64 @unsigned_sat_variable_i64_using_cmp_sum(i64 %x, i64 %y) {
 define i64 @unsigned_sat_variable_i64_using_cmp_notval(i64 %x, i64 %y) {
 ; CHECK-LABEL: unsigned_sat_variable_i64_using_cmp_notval:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    add x8, x0, x1
-; CHECK-NEXT:    cmn x1, x0
+; CHECK-NEXT:    adds x8, x1, x0
 ; CHECK-NEXT:    csinv x0, x8, xzr, lo
 ; CHECK-NEXT:    ret
   %noty = xor i64 %y, -1
