[llvm] e9eaee9 - [AArch64] Reassociate sub(x, add(m1, m2)) to sub(sub(x, m1), m2)

David Green via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 13 06:35:17 PST 2023


Author: David Green
Date: 2023-02-13T14:35:10Z
New Revision: e9eaee9da196265d20dbeaf7920c24ccb33e2d04

URL: https://github.com/llvm/llvm-project/commit/e9eaee9da196265d20dbeaf7920c24ccb33e2d04
DIFF: https://github.com/llvm/llvm-project/commit/e9eaee9da196265d20dbeaf7920c24ccb33e2d04.diff

LOG: [AArch64] Reassociate sub(x, add(m1, m2)) to sub(sub(x, m1), m2)

The middle end will reassociate sub(sub(x, m1), m2) into sub(x, add(m1, m2)).
This patch reassociates it back, allowing more mls instructions to be created.

Differential Revision: https://reviews.llvm.org/D143143
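
As a minimal illustration of the shape this targets (the function and value
names here are hypothetical, not taken from the patch):

    define i64 @example(i64 %x, i64 %a, i64 %b, i64 %c, i64 %d) {
      ; middle-end reassociated form: x - (a*b + c*d)
      %m1 = mul i64 %a, %b
      %m2 = mul i64 %c, %d
      %add = add i64 %m1, %m2
      %sub = sub i64 %x, %add
      ret i64 %sub
    }

With this combine the DAG is rewritten as (x - m1) - m2, so instruction
selection can emit two msub instructions instead of mul + madd + sub, as the
updated reassocmls.ll checks below show.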

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/arm64-vmul.ll
    llvm/test/CodeGen/AArch64/reassocmls.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7ad92aac3aab..4db2b10ed8bb 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17703,6 +17703,32 @@ static SDValue performAddCombineForShiftedOperands(SDNode *N,
   return SDValue();
 }
 
+// The middle end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2)).
+// This reassociates it back to allow the creation of more mls instructions.
+static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) {
+  if (N->getOpcode() != ISD::SUB)
+    return SDValue();
+  SDValue Add = N->getOperand(1);
+  if (Add.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue X = N->getOperand(0);
+  if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(X)))
+    return SDValue();
+  SDValue M1 = Add.getOperand(0);
+  SDValue M2 = Add.getOperand(1);
+  if (M1.getOpcode() != ISD::MUL && M1.getOpcode() != AArch64ISD::SMULL &&
+      M1.getOpcode() != AArch64ISD::UMULL)
+    return SDValue();
+  if (M2.getOpcode() != ISD::MUL && M2.getOpcode() != AArch64ISD::SMULL &&
+      M2.getOpcode() != AArch64ISD::UMULL)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, X, M1);
+  return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
@@ -17719,6 +17745,8 @@ static SDValue performAddSubCombine(SDNode *N,
     return Val;
   if (SDValue Val = performAddCombineForShiftedOperands(N, DAG))
     return Val;
+  if (SDValue Val = performSubAddMULCombine(N, DAG))
+    return Val;
 
   return performAddSubLongCombine(N, DCI, DAG);
 }
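
Note that the combine bails out when the sub's first operand is itself a
constant (the isConstantIntBuildVectorOrConstantInt check above). A small
sketch of a case it therefore leaves untouched (names hypothetical); other
existing patterns may still pick such cases up, as the new mls_v8i16_C test
below suggests for the vector splat case:

    define i64 @const_lhs(i64 %a, i64 %b, i64 %c, i64 %d) {
      ; the minuend is a constant, so performSubAddMULCombine
      ; returns SDValue() without rewriting this DAG
      %m1 = mul i64 %a, %b
      %m2 = mul i64 %c, %d
      %add = add i64 %m1, %m2
      %sub = sub i64 10, %add
      ret i64 %sub
    }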

diff --git a/llvm/test/CodeGen/AArch64/arm64-vmul.ll b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
index 7f743f605f25..3a9f0319b06e 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vmul.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
@@ -457,12 +457,11 @@ define <2 x i64> @smlsl2d(ptr %A, ptr %B, ptr %C) nounwind {
 define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: smlsl8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    smlal.8h v0, v1, v2
-; CHECK-NEXT:    sub.8h v0, v3, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    smlsl.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlsl.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %smull.1 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -476,13 +475,12 @@ define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <
 define void @smlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: smlsl2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    smlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    sub.2d v0, v1, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    smlsl.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlsl.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %smull.1 = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
@@ -738,12 +736,11 @@ define <2 x i64> @umlsl2d(ptr %A, ptr %B, ptr %C) nounwind {
 define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: umlsl8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    umlal.8h v0, v1, v2
-; CHECK-NEXT:    sub.8h v0, v3, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    umlsl.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlsl.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %umull.1 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -757,13 +754,12 @@ define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <
 define void @umlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: umlsl2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    umlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    sub.2d v0, v1, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    umlsl.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlsl.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %umull.1 = tail call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %v1, <2 x i32> %v3)

diff --git a/llvm/test/CodeGen/AArch64/reassocmls.ll b/llvm/test/CodeGen/AArch64/reassocmls.ll
index cf201caac4ab..731d973d0017 100644
--- a/llvm/test/CodeGen/AArch64/reassocmls.ll
+++ b/llvm/test/CodeGen/AArch64/reassocmls.ll
@@ -4,9 +4,8 @@
 define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 ; CHECK-LABEL: smlsl_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull x8, w4, w3
-; CHECK-NEXT:    smaddl x8, w2, w1, x8
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    smsubl x8, w4, w3, x0
+; CHECK-NEXT:    smsubl x0, w2, w1, x8
 ; CHECK-NEXT:    ret
   %be = sext i32 %b to i64
   %ce = sext i32 %c to i64
@@ -22,9 +21,8 @@ define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 ; CHECK-LABEL: umlsl_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull x8, w4, w3
-; CHECK-NEXT:    umaddl x8, w2, w1, x8
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    umsubl x8, w4, w3, x0
+; CHECK-NEXT:    umsubl x0, w2, w1, x8
 ; CHECK-NEXT:    ret
   %be = zext i32 %b to i64
   %ce = zext i32 %c to i64
@@ -40,9 +38,8 @@ define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
 ; CHECK-LABEL: mls_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mul x8, x2, x1
-; CHECK-NEXT:    madd x8, x4, x3, x8
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    msub x8, x4, x3, x0
+; CHECK-NEXT:    msub x0, x2, x1, x8
 ; CHECK-NEXT:    ret
   %m1.neg = mul i64 %c, %b
   %m2.neg = mul i64 %e, %d
@@ -54,9 +51,8 @@ define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
 define i16 @mls_i16(i16 %a, i16 %b, i16 %c, i16 %d, i16 %e) {
 ; CHECK-LABEL: mls_i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mul w8, w2, w1
-; CHECK-NEXT:    madd w8, w4, w3, w8
-; CHECK-NEXT:    sub w0, w0, w8
+; CHECK-NEXT:    msub w8, w4, w3, w0
+; CHECK-NEXT:    msub w0, w2, w1, w8
 ; CHECK-NEXT:    ret
   %m1.neg = mul i16 %c, %b
   %m2.neg = mul i16 %e, %d
@@ -97,9 +93,8 @@ define i64 @mls_i64_C(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
 define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) {
 ; CHECK-LABEL: smlsl_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull v3.8h, v4.8b, v3.8b
-; CHECK-NEXT:    smlal v3.8h, v2.8b, v1.8b
-; CHECK-NEXT:    sub v0.8h, v0.8h, v3.8h
+; CHECK-NEXT:    smlsl v0.8h, v4.8b, v3.8b
+; CHECK-NEXT:    smlsl v0.8h, v2.8b, v1.8b
 ; CHECK-NEXT:    ret
   %be = sext <8 x i8> %b to <8 x i16>
   %ce = sext <8 x i8> %c to <8 x i16>
@@ -115,9 +110,8 @@ define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %
 define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) {
 ; CHECK-LABEL: umlsl_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull v3.8h, v4.8b, v3.8b
-; CHECK-NEXT:    umlal v3.8h, v2.8b, v1.8b
-; CHECK-NEXT:    sub v0.8h, v0.8h, v3.8h
+; CHECK-NEXT:    umlsl v0.8h, v4.8b, v3.8b
+; CHECK-NEXT:    umlsl v0.8h, v2.8b, v1.8b
 ; CHECK-NEXT:    ret
   %be = zext <8 x i8> %b to <8 x i16>
   %ce = zext <8 x i8> %c to <8 x i16>
@@ -133,9 +127,8 @@ define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %
 define <8 x i16> @mls_v8i16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> %d, <8 x i16> %e) {
 ; CHECK-LABEL: mls_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mul v1.8h, v2.8h, v1.8h
-; CHECK-NEXT:    mla v1.8h, v4.8h, v3.8h
-; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    mls v0.8h, v4.8h, v3.8h
+; CHECK-NEXT:    mls v0.8h, v2.8h, v1.8h
 ; CHECK-NEXT:    ret
   %m1.neg = mul <8 x i16> %c, %b
   %m2.neg = mul <8 x i16> %e, %d
@@ -157,6 +150,20 @@ define <8 x i16> @mla_v8i16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16>
   ret <8 x i16> %s2
 }
 
+define <8 x i16> @mls_v8i16_C(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> %d, <8 x i16> %e) {
+; CHECK-LABEL: mls_v8i16_C:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v0.8h, #10
+; CHECK-NEXT:    mls v0.8h, v4.8h, v3.8h
+; CHECK-NEXT:    mls v0.8h, v2.8h, v1.8h
+; CHECK-NEXT:    ret
+  %m1.neg = mul <8 x i16> %c, %b
+  %m2.neg = mul <8 x i16> %e, %d
+  %reass.add = add <8 x i16> %m2.neg, %m1.neg
+  %s2 = sub <8 x i16> <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>, %reass.add
+  ret <8 x i16> %s2
+}
+
 
 define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8> %b, <vscale x 8 x i8> %c, <vscale x 8 x i8> %d, <vscale x 8 x i8> %e) {
 ; CHECK-LABEL: smlsl_nxv8i16:
@@ -166,9 +173,8 @@ define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8
 ; CHECK-NEXT:    sxtb z4.h, p0/m, z4.h
 ; CHECK-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEXT:    sxtb z2.h, p0/m, z2.h
-; CHECK-NEXT:    mul z3.h, z4.h, z3.h
-; CHECK-NEXT:    mla z3.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z4.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    ret
   %be = sext <vscale x 8 x i8> %b to <vscale x 8 x i16>
   %ce = sext <vscale x 8 x i8> %c to <vscale x 8 x i16>
@@ -184,14 +190,13 @@ define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8
 define <vscale x 8 x i16> @umlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8> %b, <vscale x 8 x i8> %c, <vscale x 8 x i8> %d, <vscale x 8 x i8> %e) {
 ; CHECK-LABEL: umlsl_nxv8i16:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    and z3.h, z3.h, #0xff
 ; CHECK-NEXT:    and z4.h, z4.h, #0xff
-; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEXT:    mul z3.h, z4.h, z3.h
-; CHECK-NEXT:    mla z3.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z4.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    ret
   %be = zext <vscale x 8 x i8> %b to <vscale x 8 x i16>
   %ce = zext <vscale x 8 x i8> %c to <vscale x 8 x i16>
@@ -208,9 +213,8 @@ define <vscale x 8 x i16> @mls_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16>
 ; CHECK-LABEL: mls_nxv8i16:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    mul z3.h, z4.h, z3.h
-; CHECK-NEXT:    mla z3.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z4.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    ret
   %m1.neg = mul <vscale x 8 x i16> %c, %b
   %m2.neg = mul <vscale x 8 x i16> %e, %d


        

