[llvm] 581f837 - [ARM] Fold (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 24 02:41:05 PST 2021
Author: David Green
Date: 2021-11-24T10:41:00Z
New Revision: 581f837355b9523bd3217fb05eed3d577d51b95d
URL: https://github.com/llvm/llvm-project/commit/581f837355b9523bd3217fb05eed3d577d51b95d
DIFF: https://github.com/llvm/llvm-project/commit/581f837355b9523bd3217fb05eed3d577d51b95d.diff
LOG: [ARM] Fold (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x)
This is similar to D113574, but as a DAG combine, not tablegen patterns.
Doing the fold as a DAG combine allows the fadd to be folded with a
fmul, finally producing a predicated vfma. It performs the same fold of
fadd(x, vselect(p, y, -0.0)) to vselect(p, fadd(x, y), x) using -0.0 as
the identity value of a fadd.
Differential Revision: https://reviews.llvm.org/D113584
Added:
Modified:
llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 239e2270966f..23c1a6e8cf21 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -1017,6 +1017,9 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::SELECT);
setTargetDAGCombine(ISD::SELECT_CC);
}
+ if (Subtarget->hasMVEFloatOps()) {
+ setTargetDAGCombine(ISD::FADD);
+ }
if (!Subtarget->hasFP64()) {
// When targeting a floating-point unit with only single-precision
@@ -16407,6 +16410,42 @@ static SDValue PerformVCVTCombine(SDNode *N, SelectionDAG &DAG,
return FixConv;
}
+static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
+ const ARMSubtarget *Subtarget) {
+ if (!Subtarget->hasMVEFloatOps())
+ return SDValue(); // Bail out without folding unless MVE float ops exist.
+
+ // Turn (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x).
+ // Where c is false the original adds -0.0, the fadd identity, leaving x. The
+ // new form is easier to turn into a predicated vadd, or a predicated vfma.
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ // The identity element for a fadd is -0.0; match the bitcast VMOVIMM forms of it.
+ auto isNegativeZeroSplat = [&](SDValue Op) {
+ if (Op.getOpcode() != ISD::BITCAST ||
+ Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM)
+ return false;
+ if (VT == MVT::v4f32 && Op.getOperand(0).getConstantOperandVal(0) == 1664)
+ return true; // 1664 == 0x680: the VMOVIMM operand taken to mean a -0.0 v4f32 splat.
+ if (VT == MVT::v8f16 && Op.getOperand(0).getConstantOperandVal(0) == 2688)
+ return true; // 2688 == 0xA80: the VMOVIMM operand taken to mean a -0.0 v8f16 splat.
+ return false;
+ };
+
+ if (Op0.getOpcode() == ISD::VSELECT && Op1.getOpcode() != ISD::VSELECT)
+ std::swap(Op0, Op1); // Move a lone vselect into Op1; if both are vselects only Op1 is inspected.
+
+ if (Op1.getOpcode() != ISD::VSELECT ||
+ !isNegativeZeroSplat(Op1.getOperand(2)))
+ return SDValue(); // Op1 is not a vselect whose false value is a -0.0 splat.
+ SDValue FAdd =
+ DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), N->getFlags()); // x + y, reusing N's flags.
+ return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0); // c ? (x + y) : x
+}
+
/// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)
/// can replace combinations of VCVT (integer to floating-point) and VDIV
/// when the VDIV has a constant operand that is a power of 2.
@@ -18201,6 +18240,8 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
return PerformVCVTCombine(N, DCI.DAG, Subtarget);
+ case ISD::FADD:
+ return PerformFAddVSelectCombine(N, DCI.DAG, Subtarget);
case ISD::FDIV:
return PerformVDIVCombine(N, DCI.DAG, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
diff --git a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
index 238133c21e20..e3e23f6524ba 100644
--- a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
@@ -470,10 +470,9 @@ entry:
define arm_aapcs_vfpcc <4 x float> @fma_v4f32_x(<4 x float> %x, <4 x float> %y, <4 x float> %z, i32 %n) {
; CHECK-LABEL: fma_v4f32_x:
; CHECK: @ %bb.0: @ %entry
-; CHECK-NEXT: vmul.f32 q1, q1, q2
; CHECK-NEXT: vctp.32 r0
; CHECK-NEXT: vpst
-; CHECK-NEXT: vaddt.f32 q0, q0, q1
+; CHECK-NEXT: vfmat.f32 q0, q1, q2
; CHECK-NEXT: bx lr
entry:
%c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
@@ -486,10 +485,9 @@ entry:
define arm_aapcs_vfpcc <8 x half> @fma_v8f16_x(<8 x half> %x, <8 x half> %y, <8 x half> %z, i32 %n) {
; CHECK-LABEL: fma_v8f16_x:
; CHECK: @ %bb.0: @ %entry
-; CHECK-NEXT: vmul.f16 q1, q1, q2
; CHECK-NEXT: vctp.16 r0
; CHECK-NEXT: vpst
-; CHECK-NEXT: vaddt.f16 q0, q0, q1
+; CHECK-NEXT: vfmat.f16 q0, q1, q2
; CHECK-NEXT: bx lr
entry:
%c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n)
@@ -2422,7 +2420,7 @@ define arm_aapcs_vfpcc <4 x float> @faddqr_v4f32_y(<4 x float> %x, float %y, i32
; CHECK-NEXT: vctp.32 r0
; CHECK-NEXT: vdup.32 q1, r1
; CHECK-NEXT: vpst
-; CHECK-NEXT: vaddt.f32 q1, q1, q0
+; CHECK-NEXT: vaddt.f32 q1, q0, r1
; CHECK-NEXT: vmov q0, q1
; CHECK-NEXT: bx lr
entry:
@@ -2441,7 +2439,7 @@ define arm_aapcs_vfpcc <8 x half> @faddqr_v8f16_y(<8 x half> %x, half %y, i32 %n
; CHECK-NEXT: vctp.16 r0
; CHECK-NEXT: vdup.16 q1, r1
; CHECK-NEXT: vpst
-; CHECK-NEXT: vaddt.f16 q1, q1, q0
+; CHECK-NEXT: vaddt.f16 q1, q0, r1
; CHECK-NEXT: vmov q0, q1
; CHECK-NEXT: bx lr
entry:
More information about the llvm-commits
mailing list