[llvm] 581f837 - [ARM] Fold (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x)

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 24 02:41:05 PST 2021


Author: David Green
Date: 2021-11-24T10:41:00Z
New Revision: 581f837355b9523bd3217fb05eed3d577d51b95d

URL: https://github.com/llvm/llvm-project/commit/581f837355b9523bd3217fb05eed3d577d51b95d
DIFF: https://github.com/llvm/llvm-project/commit/581f837355b9523bd3217fb05eed3d577d51b95d.diff

LOG: [ARM] Fold (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x)

This is similar to D113574, but as a DAG combine rather than as tablegen
patterns. Doing the fold as a DAG combine allows the fadd to be folded with
an fmul, finally producing a predicated vfma. It performs the same fold of
fadd(x, vselect(p, y, -0.0)) to vselect(p, fadd(x, y), x), using -0.0 as
the identity value of an fadd.

Differential Revision: https://reviews.llvm.org/D113584

Added: 
    

Modified: 
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 239e2270966f..23c1a6e8cf21 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -1017,6 +1017,9 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine(ISD::SELECT);
     setTargetDAGCombine(ISD::SELECT_CC);
   }
+  if (Subtarget->hasMVEFloatOps()) {
+    setTargetDAGCombine(ISD::FADD);
+  }
 
   if (!Subtarget->hasFP64()) {
     // When targeting a floating-point unit with only single-precision
@@ -16407,6 +16410,42 @@ static SDValue PerformVCVTCombine(SDNode *N, SelectionDAG &DAG,
   return FixConv;
 }
 
+static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
+                                         const ARMSubtarget *Subtarget) {
+  if (!Subtarget->hasMVEFloatOps())
+    return SDValue();
+
+  // Turn (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x)
+  // The second form can be more easily turned into a predicated vadd, and
+  // possibly combined into a fma to become a predicated vfma.
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  // The identity element for a fadd is -0.0, which these VMOV's represent.
+  auto isNegativeZeroSplat = [&](SDValue Op) {
+    if (Op.getOpcode() != ISD::BITCAST ||
+        Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM)
+      return false;
+    if (VT == MVT::v4f32 && Op.getOperand(0).getConstantOperandVal(0) == 1664)
+      return true;
+    if (VT == MVT::v8f16 && Op.getOperand(0).getConstantOperandVal(0) == 2688)
+      return true;
+    return false;
+  };
+
+  if (Op0.getOpcode() == ISD::VSELECT && Op1.getOpcode() != ISD::VSELECT)
+    std::swap(Op0, Op1);
+
+  if (Op1.getOpcode() != ISD::VSELECT ||
+      !isNegativeZeroSplat(Op1.getOperand(2)))
+    return SDValue();
+  SDValue FAdd =
+      DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), N->getFlags());
+  return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0);
+}
+
 /// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)
 /// can replace combinations of VCVT (integer to floating-point) and VDIV
 /// when the VDIV has a constant operand that is a power of 2.
@@ -18201,6 +18240,8 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::FP_TO_SINT:
   case ISD::FP_TO_UINT:
     return PerformVCVTCombine(N, DCI.DAG, Subtarget);
+  case ISD::FADD:
+    return PerformFAddVSelectCombine(N, DCI.DAG, Subtarget);
   case ISD::FDIV:
     return PerformVDIVCombine(N, DCI.DAG, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:

diff  --git a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
index 238133c21e20..e3e23f6524ba 100644
--- a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
@@ -470,10 +470,9 @@ entry:
 define arm_aapcs_vfpcc <4 x float> @fma_v4f32_x(<4 x float> %x, <4 x float> %y, <4 x float> %z, i32 %n) {
 ; CHECK-LABEL: fma_v4f32_x:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.f32 q1, q1, q2
 ; CHECK-NEXT:    vctp.32 r0
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f32 q0, q0, q1
+; CHECK-NEXT:    vfmat.f32 q0, q1, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
@@ -486,10 +485,9 @@ entry:
 define arm_aapcs_vfpcc <8 x half> @fma_v8f16_x(<8 x half> %x, <8 x half> %y, <8 x half> %z, i32 %n) {
 ; CHECK-LABEL: fma_v8f16_x:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.f16 q1, q1, q2
 ; CHECK-NEXT:    vctp.16 r0
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f16 q0, q0, q1
+; CHECK-NEXT:    vfmat.f16 q0, q1, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n)
@@ -2422,7 +2420,7 @@ define arm_aapcs_vfpcc <4 x float> @faddqr_v4f32_y(<4 x float> %x, float %y, i32
 ; CHECK-NEXT:    vctp.32 r0
 ; CHECK-NEXT:    vdup.32 q1, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f32 q1, q1, q0
+; CHECK-NEXT:    vaddt.f32 q1, q0, r1
 ; CHECK-NEXT:    vmov q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
@@ -2441,7 +2439,7 @@ define arm_aapcs_vfpcc <8 x half> @faddqr_v8f16_y(<8 x half> %x, half %y, i32 %n
 ; CHECK-NEXT:    vctp.16 r0
 ; CHECK-NEXT:    vdup.16 q1, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f16 q1, q1, q0
+; CHECK-NEXT:    vaddt.f16 q1, q0, r1
 ; CHECK-NEXT:    vmov q0, q1
 ; CHECK-NEXT:    bx lr
 entry:


        


More information about the llvm-commits mailing list