[llvm] 007917b - [MVE] Fold fadd(select(..., +0.0)) into a predicated fadd

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 10 03:10:50 PDT 2022


Author: David Sherwood
Date: 2022-06-10T11:09:55+01:00
New Revision: 007917b95ce2b569a8e5e90cd9b819676a6bb364

URL: https://github.com/llvm/llvm-project/commit/007917b95ce2b569a8e5e90cd9b819676a6bb364
DIFF: https://github.com/llvm/llvm-project/commit/007917b95ce2b569a8e5e90cd9b819676a6bb364.diff

LOG: [MVE] Fold fadd(select(..., +0.0)) into a predicated fadd

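We already have patterns for matching fadd(select(..., -0.0)),
but an upcoming patch will lead to patterns using +0.0 as the
identity instead of -0.0. I'm adding support for these patterns
now to avoid any regressions for MVE. Because +0.0 is only an
identity for fadd when the sign of zero can be ignored, the new
fold is applied only when the fadd carries the nsz flag.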

Differential Revision: https://reviews.llvm.org/D127275

Added: 
    

Modified: 
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 5f95acf35c889..8e9ff53985ecf 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -16704,14 +16704,16 @@ static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
 
-  // The identity element for a fadd is -0.0, which these VMOV's represent.
-  auto isNegativeZeroSplat = [&](SDValue Op) {
+  // The identity element for a fadd is -0.0, or also +0.0 when the nsz flag is
+  // set. The VMOVs matched below are splats of those identity values.
+  auto isIdentitySplat = [&](SDValue Op, bool NSZ) {
     if (Op.getOpcode() != ISD::BITCAST ||
         Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM)
       return false;
-    if (VT == MVT::v4f32 && Op.getOperand(0).getConstantOperandVal(0) == 1664)
+    uint64_t ImmVal = Op.getOperand(0).getConstantOperandVal(0);
+    if (VT == MVT::v4f32 && (ImmVal == 1664 || (ImmVal == 0 && NSZ)))
       return true;
-    if (VT == MVT::v8f16 && Op.getOperand(0).getConstantOperandVal(0) == 2688)
+    if (VT == MVT::v8f16 && (ImmVal == 2688 || (ImmVal == 0 && NSZ)))
       return true;
     return false;
   };
@@ -16719,12 +16721,17 @@ static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
   if (Op0.getOpcode() == ISD::VSELECT && Op1.getOpcode() != ISD::VSELECT)
     std::swap(Op0, Op1);
 
-  if (Op1.getOpcode() != ISD::VSELECT ||
-      !isNegativeZeroSplat(Op1.getOperand(2)))
+  if (Op1.getOpcode() != ISD::VSELECT)
     return SDValue();
+
+  SDNodeFlags FaddFlags = N->getFlags();
+  bool NSZ = FaddFlags.hasNoSignedZeros();
+  if (!isIdentitySplat(Op1.getOperand(2), NSZ))
+    return SDValue();
+
   SDValue FAdd =
-      DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), N->getFlags());
-  return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0);
+      DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), FaddFlags);
+  return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0, FaddFlags);
 }
 
 /// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)

diff  --git a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
index e3e23f6524ba0..0773b65b5dfe0 100644
--- a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
@@ -363,6 +363,36 @@ entry:
   ret <4 x float> %b
 }
 
+define arm_aapcs_vfpcc <4 x float> @fadd_v4f32_x2(<4 x float> %x, <4 x float> %y, i32 %n) {
+; CHECK-LABEL: fadd_v4f32_x2:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov.i32 q2, #0x0
+; CHECK-NEXT:    vctp.32 r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vmovt q2, q1
+; CHECK-NEXT:    vadd.f32 q0, q2, q0
+; CHECK-NEXT:    bx lr
+entry:
+  %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
+  %a = select <4 x i1> %c, <4 x float> %y, <4 x float> <float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>
+  %b = fadd <4 x float> %a, %x
+  ret <4 x float> %b
+}
+
+define arm_aapcs_vfpcc <4 x float> @fadd_v4f32_x3(<4 x float> %x, <4 x float> %y, i32 %n) {
+; CHECK-LABEL: fadd_v4f32_x3:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vctp.32 r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vaddt.f32 q0, q0, q1
+; CHECK-NEXT:    bx lr
+entry:
+  %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
+  %a = select <4 x i1> %c, <4 x float> %y, <4 x float> <float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>
+  %b = fadd nsz <4 x float> %a, %x
+  ret <4 x float> %b
+}
+
 define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x(<8 x half> %x, <8 x half> %y, i32 %n) {
 ; CHECK-LABEL: fadd_v8f16_x:
 ; CHECK:       @ %bb.0: @ %entry
@@ -377,6 +407,36 @@ entry:
   ret <8 x half> %b
 }
 
+define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x2(<8 x half> %x, <8 x half> %y, i32 %n) {
+; CHECK-LABEL: fadd_v8f16_x2:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov.i32 q2, #0x0
+; CHECK-NEXT:    vctp.16 r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vmovt q2, q1
+; CHECK-NEXT:    vadd.f16 q0, q2, q0
+; CHECK-NEXT:    bx lr
+entry:
+  %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n)
+  %a = select <8 x i1> %c, <8 x half> %y, <8 x half> <half 0x0000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000>
+  %b = fadd <8 x half> %a, %x
+  ret <8 x half> %b
+}
+
+define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x3(<8 x half> %x, <8 x half> %y, i32 %n) {
+; CHECK-LABEL: fadd_v8f16_x3:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vctp.16 r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vaddt.f16 q0, q0, q1
+; CHECK-NEXT:    bx lr
+entry:
+  %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n)
+  %a = select <8 x i1> %c, <8 x half> %y, <8 x half> <half 0x0000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000>
+  %b = fadd nsz <8 x half> %a, %x
+  ret <8 x half> %b
+}
+
 define arm_aapcs_vfpcc <4 x float> @fsub_v4f32_x(<4 x float> %x, <4 x float> %y, i32 %n) {
 ; CHECK-LABEL: fsub_v4f32_x:
 ; CHECK:       @ %bb.0: @ %entry


        


More information about the llvm-commits mailing list