[llvm] 007917b - [MVE] Fold fadd(select(..., +0.0)) into a predicated fadd
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 10 03:10:50 PDT 2022
Author: David Sherwood
Date: 2022-06-10T11:09:55+01:00
New Revision: 007917b95ce2b569a8e5e90cd9b819676a6bb364
URL: https://github.com/llvm/llvm-project/commit/007917b95ce2b569a8e5e90cd9b819676a6bb364
DIFF: https://github.com/llvm/llvm-project/commit/007917b95ce2b569a8e5e90cd9b819676a6bb364.diff
LOG: [MVE] Fold fadd(select(..., +0.0)) into a predicated fadd
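We already have patterns for matching fadd(select(..., -0.0)),
but an upcoming patch will produce patterns that use +0.0 as the
identity instead of -0.0. Because +0.0 is only an identity for fadd
when the nsz (no signed zeros) flag is set, the new fold is restricted
to fadds carrying nsz. I'm adding support for these patterns now to
avoid any regressions for MVE.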
Differential Revision: https://reviews.llvm.org/D127275
Added:
Modified:
llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 5f95acf35c889..8e9ff53985ecf 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -16704,14 +16704,16 @@ static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
EVT VT = N->getValueType(0);
SDLoc DL(N);
- // The identity element for a fadd is -0.0, which these VMOV's represent.
- auto isNegativeZeroSplat = [&](SDValue Op) {
+ // The identity element for a fadd is -0.0 or +0.0 when the nsz flag is set,
+ // which these VMOV's represent.
+ auto isIdentitySplat = [&](SDValue Op, bool NSZ) {
if (Op.getOpcode() != ISD::BITCAST ||
Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM)
return false;
- if (VT == MVT::v4f32 && Op.getOperand(0).getConstantOperandVal(0) == 1664)
+ uint64_t ImmVal = Op.getOperand(0).getConstantOperandVal(0);
+ if (VT == MVT::v4f32 && (ImmVal == 1664 || (ImmVal == 0 && NSZ)))
return true;
- if (VT == MVT::v8f16 && Op.getOperand(0).getConstantOperandVal(0) == 2688)
+ if (VT == MVT::v8f16 && (ImmVal == 2688 || (ImmVal == 0 && NSZ)))
return true;
return false;
};
@@ -16719,12 +16721,17 @@ static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
if (Op0.getOpcode() == ISD::VSELECT && Op1.getOpcode() != ISD::VSELECT)
std::swap(Op0, Op1);
- if (Op1.getOpcode() != ISD::VSELECT ||
- !isNegativeZeroSplat(Op1.getOperand(2)))
+ if (Op1.getOpcode() != ISD::VSELECT)
return SDValue();
+
+ SDNodeFlags FaddFlags = N->getFlags();
+ bool NSZ = FaddFlags.hasNoSignedZeros();
+ if (!isIdentitySplat(Op1.getOperand(2), NSZ))
+ return SDValue();
+
SDValue FAdd =
- DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), N->getFlags());
- return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0);
+ DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), FaddFlags);
+ return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0, FaddFlags);
}
/// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)
diff --git a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
index e3e23f6524ba0..0773b65b5dfe0 100644
--- a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
@@ -363,6 +363,36 @@ entry:
ret <4 x float> %b
}
+define arm_aapcs_vfpcc <4 x float> @fadd_v4f32_x2(<4 x float> %x, <4 x float> %y, i32 %n) {
+; CHECK-LABEL: fadd_v4f32_x2:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vmov.i32 q2, #0x0
+; CHECK-NEXT: vctp.32 r0
+; CHECK-NEXT: vpst
+; CHECK-NEXT: vmovt q2, q1
+; CHECK-NEXT: vadd.f32 q0, q2, q0
+; CHECK-NEXT: bx lr
+entry:
+ %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
+ %a = select <4 x i1> %c, <4 x float> %y, <4 x float> <float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>
+ %b = fadd <4 x float> %a, %x
+ ret <4 x float> %b
+}
+
+define arm_aapcs_vfpcc <4 x float> @fadd_v4f32_x3(<4 x float> %x, <4 x float> %y, i32 %n) {
+; CHECK-LABEL: fadd_v4f32_x3:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vctp.32 r0
+; CHECK-NEXT: vpst
+; CHECK-NEXT: vaddt.f32 q0, q0, q1
+; CHECK-NEXT: bx lr
+entry:
+ %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
+ %a = select <4 x i1> %c, <4 x float> %y, <4 x float> <float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>
+ %b = fadd nsz <4 x float> %a, %x
+ ret <4 x float> %b
+}
+
define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x(<8 x half> %x, <8 x half> %y, i32 %n) {
; CHECK-LABEL: fadd_v8f16_x:
; CHECK: @ %bb.0: @ %entry
@@ -377,6 +407,36 @@ entry:
ret <8 x half> %b
}
+define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x2(<8 x half> %x, <8 x half> %y, i32 %n) {
+; CHECK-LABEL: fadd_v8f16_x2:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vmov.i32 q2, #0x0
+; CHECK-NEXT: vctp.16 r0
+; CHECK-NEXT: vpst
+; CHECK-NEXT: vmovt q2, q1
+; CHECK-NEXT: vadd.f16 q0, q2, q0
+; CHECK-NEXT: bx lr
+entry:
+ %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n)
+ %a = select <8 x i1> %c, <8 x half> %y, <8 x half> <half 0x0000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000>
+ %b = fadd <8 x half> %a, %x
+ ret <8 x half> %b
+}
+
+define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x3(<8 x half> %x, <8 x half> %y, i32 %n) {
+; CHECK-LABEL: fadd_v8f16_x3:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vctp.16 r0
+; CHECK-NEXT: vpst
+; CHECK-NEXT: vaddt.f16 q0, q0, q1
+; CHECK-NEXT: bx lr
+entry:
+ %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n)
+ %a = select <8 x i1> %c, <8 x half> %y, <8 x half> <half 0x0000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000>
+ %b = fadd nsz <8 x half> %a, %x
+ ret <8 x half> %b
+}
+
define arm_aapcs_vfpcc <4 x float> @fsub_v4f32_x(<4 x float> %x, <4 x float> %y, i32 %n) {
; CHECK-LABEL: fsub_v4f32_x:
; CHECK: @ %bb.0: @ %entry
More information about the llvm-commits mailing list