[llvm] b4df2b2 - [ARM] Combine fadd into fcmla

David Green via llvm-commits <llvm-commits at lists.llvm.org>
Wed Apr 5 02:31:26 PDT 2023


Author: David Green
Date: 2023-04-05T10:31:19+01:00
New Revision: b4df2b2c6c75f02e2c26c3c947718816ca0be56a

URL: https://github.com/llvm/llvm-project/commit/b4df2b2c6c75f02e2c26c3c947718816ca0be56a
DIFF: https://github.com/llvm/llvm-project/commit/b4df2b2c6c75f02e2c26c3c947718816ca0be56a.diff

LOG: [ARM] Combine fadd into fcmla

This is the MVE equivalent of https://reviews.llvm.org/D146407. It adds a
target DAG combine for fadd(a, vcmla(b, c, d)) -> vcmla(fadd(a, b), c, d),
pushing the fadd into the accumulator operand of the vcmla, which can help
simplify away some additions.
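
For illustration, here is a minimal IR sketch (not part of the patch) of the
pattern the combine targets, modelled on the reassoc_f32x4 test updated below.
The function name is illustrative, and the fadd is assumed to carry the
reassoc flag, since the combine bails out unless the node has
hasAllowReassociation():

  define arm_aapcs_vfpcc <4 x float> @fadd_into_vcmla(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
  entry:
    ; vcmla with a zero accumulator, followed by a reassociable fadd of %a
    %d = tail call <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32 0, <4 x float> zeroinitializer, <4 x float> %b, <4 x float> %c)
    %res = fadd reassoc <4 x float> %d, %a
    ret <4 x float> %res
  }
  declare <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32, <4 x float>, <4 x float>, <4 x float>)

With the combine, the fadd is folded into the accumulator operand, so code of
this shape can lower to a single vcmla.f32 q0, q1, q2, #0 rather than a vcmla
into a zeroed register followed by a vadd.f32, as the updated CHECK lines in
mve-vcmla.ll show.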

Differential Revision: https://reviews.llvm.org/D147200

Added: 
    

Modified: 
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
    llvm/test/CodeGen/Thumb2/mve-vcmla.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index e6fd02ef2714b..a6b92593c4958 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -16884,6 +16884,46 @@ static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0, FaddFlags);
 }
 
+static SDValue PerformFADDVCMLACombine(SDNode *N, SelectionDAG &DAG) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  if (!N->getFlags().hasAllowReassociation())
+    return SDValue();
+
+  // Combine fadd(a, vcmla(b, c, d)) -> vcmla(fadd(a, b), c, d)
+  auto ReassocComplex = [&](SDValue A, SDValue B) {
+    if (A.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+      return SDValue();
+    unsigned Opc = A.getConstantOperandVal(0);
+    if (Opc != Intrinsic::arm_mve_vcmlaq)
+      return SDValue();
+    SDValue VCMLA = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, VT, A.getOperand(0), A.getOperand(1),
+        DAG.getNode(ISD::FADD, DL, VT, A.getOperand(2), B, N->getFlags()),
+        A.getOperand(3), A.getOperand(4));
+    VCMLA->setFlags(A->getFlags());
+    return VCMLA;
+  };
+  if (SDValue R = ReassocComplex(LHS, RHS))
+    return R;
+  if (SDValue R = ReassocComplex(RHS, LHS))
+    return R;
+
+  return SDValue();
+}
+
+static SDValue PerformFADDCombine(SDNode *N, SelectionDAG &DAG,
+                                  const ARMSubtarget *Subtarget) {
+  if (SDValue S = PerformFAddVSelectCombine(N, DAG, Subtarget))
+    return S;
+  if (SDValue S = PerformFADDVCMLACombine(N, DAG))
+    return S;
+  return SDValue();
+}
+
 /// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)
 /// can replace combinations of VCVT (integer to floating-point) and VDIV
 /// when the VDIV has a constant operand that is a power of 2.
@@ -18771,7 +18811,7 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::FP_TO_UINT:
     return PerformVCVTCombine(N, DCI.DAG, Subtarget);
   case ISD::FADD:
-    return PerformFAddVSelectCombine(N, DCI.DAG, Subtarget);
+    return PerformFADDCombine(N, DCI.DAG, Subtarget);
   case ISD::FDIV:
     return PerformVDIVCombine(N, DCI.DAG, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:

diff --git a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
index a529aa81467e0..152aaa26595cf 100644
--- a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
@@ -391,16 +391,16 @@ define <4 x float> @mul_addequal(<4 x float> %a, <4 x float> %b, <4 x float> %c)
 ; CHECK-LABEL: mul_addequal:
 ; CHECK:       @ %bb.0: @ %entry
 ; CHECK-NEXT:    vmov d0, r0, r1
-; CHECK-NEXT:    mov r1, sp
-; CHECK-NEXT:    vldrw.u32 q2, [r1]
-; CHECK-NEXT:    vmov d1, r2, r3
-; CHECK-NEXT:    add r0, sp, #16
-; CHECK-NEXT:    vcmul.f32 q3, q0, q2, #0
+; CHECK-NEXT:    mov r0, sp
+; CHECK-NEXT:    add r1, sp, #16
 ; CHECK-NEXT:    vldrw.u32 q1, [r0]
-; CHECK-NEXT:    vcmla.f32 q3, q0, q2, #90
-; CHECK-NEXT:    vadd.f32 q0, q3, q1
-; CHECK-NEXT:    vmov r0, r1, d0
-; CHECK-NEXT:    vmov r2, r3, d1
+; CHECK-NEXT:    vmov d1, r2, r3
+; CHECK-NEXT:    vldrw.u32 q2, [r1]
+; CHECK-NEXT:    vcmul.f32 q3, q0, q1, #0
+; CHECK-NEXT:    vadd.f32 q2, q3, q2
+; CHECK-NEXT:    vcmla.f32 q2, q0, q1, #90
+; CHECK-NEXT:    vmov r0, r1, d4
+; CHECK-NEXT:    vmov r2, r3, d5
 ; CHECK-NEXT:    bx lr
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>

diff --git a/llvm/test/CodeGen/Thumb2/mve-vcmla.ll b/llvm/test/CodeGen/Thumb2/mve-vcmla.ll
index d0d65bdf836ba..26b5fdbfad9ac 100644
--- a/llvm/test/CodeGen/Thumb2/mve-vcmla.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-vcmla.ll
@@ -10,9 +10,7 @@ declare <4 x float> @llvm.arm.mve.vcmulq.v4f32(i32, <4 x float>, <4 x float>)
 define arm_aapcs_vfpcc <4 x float> @reassoc_f32x4(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: reassoc_f32x4:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q3, #0x0
-; CHECK-NEXT:    vcmla.f32 q3, q1, q2, #0
-; CHECK-NEXT:    vadd.f32 q0, q3, q0
+; CHECK-NEXT:    vcmla.f32 q0, q1, q2, #0
 ; CHECK-NEXT:    bx lr
 entry:
   %d = tail call <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32 0, <4 x float> zeroinitializer, <4 x float> %b, <4 x float> %c)
@@ -23,9 +21,7 @@ entry:
 define arm_aapcs_vfpcc <4 x float> @reassoc_c_f32x4(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: reassoc_c_f32x4:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q3, #0x0
-; CHECK-NEXT:    vcmla.f32 q3, q1, q2, #90
-; CHECK-NEXT:    vadd.f32 q0, q0, q3
+; CHECK-NEXT:    vcmla.f32 q0, q1, q2, #90
 ; CHECK-NEXT:    bx lr
 entry:
   %d = tail call <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32 1, <4 x float> zeroinitializer, <4 x float> %b, <4 x float> %c)
@@ -36,9 +32,7 @@ entry:
 define arm_aapcs_vfpcc <8 x half> @reassoc_f16x4(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
 ; CHECK-LABEL: reassoc_f16x4:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q3, #0x0
-; CHECK-NEXT:    vcmla.f16 q3, q1, q2, #180
-; CHECK-NEXT:    vadd.f16 q0, q3, q0
+; CHECK-NEXT:    vcmla.f16 q0, q1, q2, #180
 ; CHECK-NEXT:    bx lr
 entry:
   %d = tail call <8 x half> @llvm.arm.mve.vcmlaq.v8f16(i32 2, <8 x half> zeroinitializer, <8 x half> %b, <8 x half> %c)
@@ -49,9 +43,7 @@ entry:
 define arm_aapcs_vfpcc <8 x half> @reassoc_c_f16x4(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
 ; CHECK-LABEL: reassoc_c_f16x4:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q3, #0x0
-; CHECK-NEXT:    vcmla.f16 q3, q1, q2, #270
-; CHECK-NEXT:    vadd.f16 q0, q0, q3
+; CHECK-NEXT:    vcmla.f16 q0, q1, q2, #270
 ; CHECK-NEXT:    bx lr
 entry:
   %d = tail call <8 x half> @llvm.arm.mve.vcmlaq.v8f16(i32 3, <8 x half> zeroinitializer, <8 x half> %b, <8 x half> %c)
