[llvm] r339043 - [X86] Recognize a splat of negate in isFNEG

Easwaran Raman via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 6 12:23:38 PDT 2018


Author: eraman
Date: Mon Aug  6 12:23:38 2018
New Revision: 339043

URL: http://llvm.org/viewvc/llvm-project?rev=339043&view=rev
Log:
[X86] Recognize a splat of negate in isFNEG

Summary:
Expand isFNEG so that we generate the appropriate F(N)M(ADD|SUB)
instructions in more cases. For example, the sequence

  a = _mm256_broadcast_ss(f)
  d = _mm256_fnmadd_ps(a, b, c)

generates an fsub and an fma without this patch, and a single fnmadd
with this change.
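
For concreteness, a compilable form of that sequence might look like the
following (a minimal sketch; the function name and signature are
illustrative, _mm256_broadcast_ss takes the scalar by pointer, and it
assumes a compiler with AVX2 and FMA enabled, e.g. -mavx2 -mfma):

    #include <immintrin.h>

    /* Splat a scalar across a YMM register, then use it as a
       multiplicand of an fnmadd: d = -(a * b) + c in each lane. */
    __m256 fnmadd_splat(float f, __m256 b, __m256 c) {
      __m256 a = _mm256_broadcast_ss(&f); /* broadcast f to all 8 lanes */
      return _mm256_fnmadd_ps(a, b, c);
    }

With this change the backend emits a single vbroadcastss followed by
vfnmadd213ps for this pattern, as the updated test7 below shows.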

Reviewers: craig.topper

Subscribers: llvm-commits, davidxl, wmi

Differential Revision: https://reviews.llvm.org/D48467

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/avx2-fma-fneg-combine.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=339043&r1=339042&r2=339043&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Mon Aug  6 12:23:38 2018
@@ -5633,6 +5633,24 @@ static bool getTargetConstantBitsFromNod
     }
     return CastBitData(UndefSrcElts, SrcEltBits);
   }
+  if (ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) {
+    unsigned SrcEltSizeInBits = VT.getScalarSizeInBits();
+    unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits;
+
+    APInt UndefSrcElts(NumSrcElts, 0);
+    SmallVector<APInt, 64> SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0));
+    for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
+      const SDValue &Src = Op.getOperand(i);
+      if (Src.isUndef()) {
+        UndefSrcElts.setBit(i);
+        continue;
+      }
+      auto *Cst = cast<ConstantFPSDNode>(Src);
+      APInt RawBits = Cst->getValueAPF().bitcastToAPInt();
+      SrcEltBits[i] = RawBits.zextOrTrunc(SrcEltSizeInBits);
+    }
+    return CastBitData(UndefSrcElts, SrcEltBits);
+  }
 
   // Extract constant bits from constant pool vector.
   if (auto *Cst = getTargetConstantFromNode(Op)) {
@@ -36971,29 +36989,72 @@ static SDValue combineTruncate(SDNode *N
 
 /// Returns the negated value if the node \p N flips sign of FP value.
 ///
-/// FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000).
+/// FP-negation node may have different forms: FNEG(x), FXOR (x, 0x80000000)
+/// or FSUB(0, x)
 /// AVX512F does not have FXOR, so FNEG is lowered as
 /// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
 /// In this case we go through all bitcasts.
-static SDValue isFNEG(SDNode *N) {
+/// This also recognizes splat of a negated value and returns the splat of that
+/// value.
+static SDValue isFNEG(SelectionDAG &DAG, SDNode *N) {
   if (N->getOpcode() == ISD::FNEG)
     return N->getOperand(0);
 
   SDValue Op = peekThroughBitcasts(SDValue(N, 0));
-  if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR)
+  auto VT = Op->getValueType(0);
+  if (auto SVOp = dyn_cast<ShuffleVectorSDNode>(Op.getNode())) {
+    // For a VECTOR_SHUFFLE(VEC1, VEC2), if the VEC2 is undef, then the negate
+    // of this is VECTOR_SHUFFLE(-VEC1, UNDEF).  The mask can be anything here.
+    if (!SVOp->getOperand(1).isUndef())
+      return SDValue();
+    if (SDValue NegOp0 = isFNEG(DAG, SVOp->getOperand(0).getNode()))
+      return DAG.getVectorShuffle(VT, SDLoc(SVOp), NegOp0, DAG.getUNDEF(VT),
+                                  SVOp->getMask());
+    return SDValue();
+  }
+  unsigned Opc = Op.getOpcode();
+  if (Opc == ISD::INSERT_VECTOR_ELT) {
+    // Negate of INSERT_VECTOR_ELT(UNDEF, V, INDEX) is INSERT_VECTOR_ELT(UNDEF,
+    // -V, INDEX).
+    SDValue InsVector = Op.getOperand(0);
+    SDValue InsVal = Op.getOperand(1);
+    if (!InsVector.isUndef())
+      return SDValue();
+    if (SDValue NegInsVal = isFNEG(DAG, InsVal.getNode()))
+      return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Op), VT, InsVector,
+                         NegInsVal, Op.getOperand(2));
+    return SDValue();
+  }
+
+  if (Opc != X86ISD::FXOR && Opc != ISD::XOR && Opc != ISD::FSUB)
     return SDValue();
 
   SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
   if (!Op1.getValueType().isFloatingPoint())
     return SDValue();
 
-  // Extract constant bits and see if they are all sign bit masks.
+  SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
+
+  // For XOR and FXOR, we want to check if constant bits of Op1 are sign bit
+  // masks. For FSUB, we have to check if constant bits of Op0 are sign bit
+  // masks and hence we swap the operands.
+  if (Opc == ISD::FSUB)
+    std::swap(Op0, Op1);
+
   APInt UndefElts;
   SmallVector<APInt, 16> EltBits;
+  // Extract constant bits and see if they are all sign bit masks. Ignore the
+  // undef elements.
   if (getTargetConstantBitsFromNode(Op1, Op1.getScalarValueSizeInBits(),
-                                    UndefElts, EltBits, false, false))
-    if (llvm::all_of(EltBits, [](APInt &I) { return I.isSignMask(); }))
-      return peekThroughBitcasts(Op.getOperand(0));
+                                    UndefElts, EltBits,
+                                    /* AllowWholeUndefs */ true,
+                                    /* AllowPartialUndefs */ false)) {
+    for (unsigned I = 0, E = EltBits.size(); I < E; I++)
+      if (!UndefElts[I] && !EltBits[I].isSignMask())
+        return SDValue();
+
+    return peekThroughBitcasts(Op0);
+  }
 
   return SDValue();
 }
@@ -37002,8 +37063,9 @@ static SDValue isFNEG(SDNode *N) {
 static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
                            const X86Subtarget &Subtarget) {
   EVT OrigVT = N->getValueType(0);
-  SDValue Arg = isFNEG(N);
-  assert(Arg.getNode() && "N is expected to be an FNEG node");
+  SDValue Arg = isFNEG(DAG, N);
+  if (!Arg)
+    return SDValue();
 
   EVT VT = Arg.getValueType();
   EVT SVT = VT.getScalarType();
@@ -37118,9 +37180,7 @@ static SDValue combineXor(SDNode *N, Sel
   if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
     return FPLogic;
 
-  if (isFNEG(N))
-    return combineFneg(N, DAG, Subtarget);
-  return SDValue();
+  return combineFneg(N, DAG, Subtarget);
 }
 
 static SDValue combineBEXTR(SDNode *N, SelectionDAG &DAG,
@@ -37253,9 +37313,8 @@ static SDValue combineFOr(SDNode *N, Sel
   if (isNullFPScalarOrVectorConst(N->getOperand(1)))
     return N->getOperand(0);
 
-  if (isFNEG(N))
-    if (SDValue NewVal = combineFneg(N, DAG, Subtarget))
-      return NewVal;
+  if (SDValue NewVal = combineFneg(N, DAG, Subtarget))
+    return NewVal;
 
   return lowerX86FPLogicOp(N, DAG, Subtarget);
 }
@@ -37940,7 +37999,7 @@ static SDValue combineFMA(SDNode *N, Sel
   SDValue C = N->getOperand(2);
 
   auto invertIfNegative = [&DAG](SDValue &V) {
-    if (SDValue NegVal = isFNEG(V.getNode())) {
+    if (SDValue NegVal = isFNEG(DAG, V.getNode())) {
       V = DAG.getBitcast(V.getValueType(), NegVal);
       return true;
     }
@@ -37948,7 +38007,7 @@ static SDValue combineFMA(SDNode *N, Sel
     // new extract from the FNEG input.
     if (V.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
         isNullConstant(V.getOperand(1))) {
-      if (SDValue NegVal = isFNEG(V.getOperand(0).getNode())) {
+      if (SDValue NegVal = isFNEG(DAG, V.getOperand(0).getNode())) {
         NegVal = DAG.getBitcast(V.getOperand(0).getValueType(), NegVal);
         V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), V.getValueType(),
                         NegVal, V.getOperand(1));
@@ -37981,7 +38040,7 @@ static SDValue combineFMADDSUB(SDNode *N
   SDLoc dl(N);
   EVT VT = N->getValueType(0);
 
-  SDValue NegVal = isFNEG(N->getOperand(2).getNode());
+  SDValue NegVal = isFNEG(DAG, N->getOperand(2).getNode());
   if (!NegVal)
     return SDValue();
 

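The new VECTOR_SHUFFLE and INSERT_VECTOR_ELT cases in isFNEG above are
what let the negation be pulled through a splat, since a splat is lowered
as a shufflevector of an insertelement into undef. The updated test8
below covers the shape where the scalar is negated first (an fsub from
-0.0 in IR) and only then broadcast; in source form that corresponds
roughly to this sketch (illustrative names, same assumptions as above):

    #include <immintrin.h>

    /* Negate the scalar, then splat it; isFNEG now sees through the
       insertelement + shufflevector splat, so the plain fmadd is
       combined into vfnmadd213ps. */
    __m256 splat_of_negate(float f, __m256 b, __m256 c) {
      float nf = -f;                       /* fsub float -0.0, %f in IR */
      __m256 a = _mm256_broadcast_ss(&nf); /* splat of the negated value */
      return _mm256_fmadd_ps(a, b, c);
    }
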
Modified: llvm/trunk/test/CodeGen/X86/avx2-fma-fneg-combine.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/avx2-fma-fneg-combine.ll?rev=339043&r1=339042&r2=339043&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/avx2-fma-fneg-combine.ll (original)
+++ llvm/trunk/test/CodeGen/X86/avx2-fma-fneg-combine.ll Mon Aug  6 12:23:38 2018
@@ -118,20 +118,14 @@ declare <2 x double> @llvm.x86.fma.vfmad
 define <8 x float> @test7(float %a, <8 x float> %b, <8 x float> %c)  {
 ; X32-LABEL: test7:
 ; X32:       # %bb.0: # %entry
-; X32-NEXT:    vmovss {{.*#+}} xmm2 = mem[0],zero,zero,zero
-; X32-NEXT:    vmovss {{.*#+}} xmm3 = mem[0],zero,zero,zero
-; X32-NEXT:    vsubps %ymm2, %ymm3, %ymm2
-; X32-NEXT:    vbroadcastss %xmm2, %ymm2
-; X32-NEXT:    vfmadd213ps {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X32-NEXT:    vbroadcastss {{[0-9]+}}(%esp), %ymm2
+; X32-NEXT:    vfnmadd213ps {{.*#+}} ymm0 = -(ymm2 * ymm0) + ymm1
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test7:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    # kill: def $xmm0 killed $xmm0 def $ymm0
-; X64-NEXT:    vmovss {{.*#+}} xmm3 = mem[0],zero,zero,zero
-; X64-NEXT:    vsubps %ymm0, %ymm3, %ymm0
 ; X64-NEXT:    vbroadcastss %xmm0, %ymm0
-; X64-NEXT:    vfmadd213ps {{.*#+}} ymm0 = (ymm1 * ymm0) + ymm2
+; X64-NEXT:    vfnmadd213ps {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm2
 ; X64-NEXT:    retq
 entry:
   %0 = insertelement <8 x float> undef, float %a, i32 0
@@ -145,19 +139,14 @@ entry:
 define <8 x float> @test8(float %a, <8 x float> %b, <8 x float> %c)  {
 ; X32-LABEL: test8:
 ; X32:       # %bb.0: # %entry
-; X32-NEXT:    vmovss {{.*#+}} xmm2 = mem[0],zero,zero,zero
-; X32-NEXT:    vbroadcastss {{.*#+}} xmm3 = [-0,-0,-0,-0]
-; X32-NEXT:    vxorps %xmm3, %xmm2, %xmm2
-; X32-NEXT:    vbroadcastss %xmm2, %ymm2
-; X32-NEXT:    vfmadd213ps {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X32-NEXT:    vbroadcastss {{[0-9]+}}(%esp), %ymm2
+; X32-NEXT:    vfnmadd213ps {{.*#+}} ymm0 = -(ymm2 * ymm0) + ymm1
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test8:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    vbroadcastss {{.*#+}} xmm3 = [-0,-0,-0,-0]
-; X64-NEXT:    vxorps %xmm3, %xmm0, %xmm0
 ; X64-NEXT:    vbroadcastss %xmm0, %ymm0
-; X64-NEXT:    vfmadd213ps {{.*#+}} ymm0 = (ymm1 * ymm0) + ymm2
+; X64-NEXT:    vfnmadd213ps {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm2
 ; X64-NEXT:    retq
 entry:
   %0 = fsub float -0.0, %a



