[llvm] r280785 - AVX512F: FMA intrinsic + FNEG - sequence optimization

Elena Demikhovsky via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 6 23:54:29 PDT 2016


Author: delena
Date: Wed Sep  7 01:54:28 2016
New Revision: 280785

URL: http://llvm.org/viewvc/llvm-project?rev=280785&view=rev
Log:
AVX512F: FMA intrinsic + FNEG - sequence optimization

The previous commit (r280368 - https://reviews.llvm.org/D23313) does not cover the AVX-512F (KNL) instruction set.
On KNL, the FNEG(x) operation is lowered to (bitcast (vpxor (bitcast x), (bitcast constfp(0x80000000)))).
This happens because FP XOR is not supported for 512-bit data types on KNL, so integer XOR is used instead.
I added a pattern match for the integer XOR form.

Differential Revision: https://reviews.llvm.org/D24221


Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/fma-fneg-combine.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=280785&r1=280784&r2=280785&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Sep  7 01:54:28 2016
@@ -29233,28 +29233,6 @@ static SDValue foldVectorXorShiftIntoCmp
   return DAG.getNode(X86ISD::PCMPGT, SDLoc(N), VT, Shift.getOperand(0), Ones);
 }
 
-static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 const X86Subtarget &Subtarget) {
-  if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
-    return Cmp;
-
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
-  if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
-    return RV;
-
-  if (Subtarget.hasCMov())
-    if (SDValue RV = combineIntegerAbs(N, DAG))
-      return RV;
-
-  if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
-    return FPLogic;
-
-  return SDValue();
-}
-
 /// This function detects the AVG pattern between vectors of unsigned i8/i16,
 /// which is c = (a + b + 1) / 2, and replace this operation with the efficient
 /// X86ISD::AVG instruction.
@@ -30363,12 +30341,68 @@ static SDValue combineTruncate(SDNode *N
   return combineVectorTruncation(N, DAG, Subtarget);
 }
 
+/// Returns the negated value if the node \p N flips the sign of an FP value.
+///
+/// An FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000).
+/// AVX512F does not have FXOR, so FNEG is lowered as
+/// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
+/// In this case we look through all bitcasts.
+static SDValue isFNEG(SDNode *N) {
+  if (N->getOpcode() == ISD::FNEG)
+    return N->getOperand(0);
+
+  SDValue Op = peekThroughBitcasts(SDValue(N, 0));
+  if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR)
+    return SDValue();
+
+  SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
+  if (!Op1.getValueType().isFloatingPoint())
+    return SDValue();
+
+  SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
+
+  unsigned EltBits = Op1.getValueType().getScalarSizeInBits();
+  auto isSignBitValue = [&](const ConstantFP *C) {
+    return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
+  };
+
+  // There is more than one way to represent the same constant on
+  // the different X86 targets. The type of the node may also depend on size.
+  //  - load scalar value and broadcast
+  //  - BUILD_VECTOR node
+  //  - load from a constant pool.
+  // We check all variants here.
+  if (Op1.getOpcode() == X86ISD::VBROADCAST) {
+    if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
+      if (isSignBitValue(cast<ConstantFP>(C)))
+        return Op0;
+
+  } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
+    if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
+      if (isSignBitValue(CN->getConstantFPValue()))
+        return Op0;
+
+  } else if (auto *C = getTargetConstantFromNode(Op1)) {
+    if (C->getType()->isVectorTy()) {
+      if (auto *SplatV = C->getSplatValue())
+        if (isSignBitValue(cast<ConstantFP>(SplatV)))
+          return Op0;
+    } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
+      if (isSignBitValue(FPConst))
+        return Op0;
+  }
+  return SDValue();
+}
+
 /// Do target-specific dag combines on floating point negations.
 static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
                            const X86Subtarget &Subtarget) {
-  EVT VT = N->getValueType(0);
+  EVT OrigVT = N->getValueType(0);
+  SDValue Arg = isFNEG(N);
+  assert(Arg.getNode() && "N is expected to be an FNEG node");
+
+  EVT VT = Arg.getValueType();
   EVT SVT = VT.getScalarType();
-  SDValue Arg = N->getOperand(0);
   SDLoc DL(N);
 
   // Let legalize expand this if it isn't a legal type yet.
@@ -30381,40 +30415,30 @@ static SDValue combineFneg(SDNode *N, Se
   if (Arg.getOpcode() == ISD::FMUL && (SVT == MVT::f32 || SVT == MVT::f64) &&
       Arg->getFlags()->hasNoSignedZeros() && Subtarget.hasAnyFMA()) {
     SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
-    return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
-                       Arg.getOperand(1), Zero);
+    SDValue NewNode = DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
+                                  Arg.getOperand(1), Zero);
+    return DAG.getBitcast(OrigVT, NewNode);
   }
 
   // If we're negating a FMA node, then we can adjust the
   // instruction to include the extra negation.
+  unsigned NewOpcode = 0;
   if (Arg.hasOneUse()) {
     switch (Arg.getOpcode()) {
-    case X86ISD::FMADD:
-      return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
-    case X86ISD::FMSUB:
-      return DAG.getNode(X86ISD::FNMADD, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
-    case X86ISD::FNMADD:
-      return DAG.getNode(X86ISD::FMSUB, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
-    case X86ISD::FNMSUB:
-      return DAG.getNode(X86ISD::FMADD, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
-    case X86ISD::FMADD_RND:
-      return DAG.getNode(X86ISD::FNMSUB_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
-    case X86ISD::FMSUB_RND:
-      return DAG.getNode(X86ISD::FNMADD_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
-    case X86ISD::FNMADD_RND:
-      return DAG.getNode(X86ISD::FMSUB_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
-    case X86ISD::FNMSUB_RND:
-      return DAG.getNode(X86ISD::FMADD_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
+    case X86ISD::FMADD:      NewOpcode = X86ISD::FNMSUB;     break;
+    case X86ISD::FMSUB:      NewOpcode = X86ISD::FNMADD;     break;
+    case X86ISD::FNMADD:     NewOpcode = X86ISD::FMSUB;      break;
+    case X86ISD::FNMSUB:     NewOpcode = X86ISD::FMADD;      break;
+    case X86ISD::FMADD_RND:  NewOpcode = X86ISD::FNMSUB_RND; break;
+    case X86ISD::FMSUB_RND:  NewOpcode = X86ISD::FNMADD_RND; break;
+    case X86ISD::FNMADD_RND: NewOpcode = X86ISD::FMSUB_RND;  break;
+    case X86ISD::FNMSUB_RND: NewOpcode = X86ISD::FMADD_RND;  break;
     }
   }
+  if (NewOpcode)
+    return DAG.getBitcast(OrigVT, DAG.getNode(NewOpcode, DL, VT,
+                                              Arg.getNode()->ops()));
+
   return SDValue();
 }
 
@@ -30442,42 +30466,28 @@ static SDValue lowerX86FPLogicOp(SDNode
   return SDValue();
 }
 
-/// Returns true if the node \p N is FNEG(x) or FXOR (x, 0x80000000).
-bool isFNEG(const SDNode *N) {
-  if (N->getOpcode() == ISD::FNEG)
-    return true;
+static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const X86Subtarget &Subtarget) {
+  if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
+    return Cmp;
 
-  if (N->getOpcode() == X86ISD::FXOR) {
-    unsigned EltBits = N->getSimpleValueType(0).getScalarSizeInBits();
-    SDValue Op1 = N->getOperand(1);
-
-    auto isSignBitValue = [&](const ConstantFP *C) {
-      return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
-    };
-
-    // There is more than one way to represent the same constant on
-    // the different X86 targets. The type of the node may also depend on size.
-    //  - load scalar value and broadcast
-    //  - BUILD_VECTOR node
-    //  - load from a constant pool.
-    // We check all variants here.
-    if (Op1.getOpcode() == X86ISD::VBROADCAST) {
-      if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
-        return isSignBitValue(cast<ConstantFP>(C));
-
-    } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
-      if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
-        return isSignBitValue(CN->getConstantFPValue());
-
-    } else if (auto *C = getTargetConstantFromNode(Op1)) {
-      if (C->getType()->isVectorTy()) {
-        if (auto *SplatV = C->getSplatValue())
-          return isSignBitValue(cast<ConstantFP>(SplatV));
-      } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
-        return isSignBitValue(FPConst);
-    }
-  }
-  return false;
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
+    return RV;
+
+  if (Subtarget.hasCMov())
+    if (SDValue RV = combineIntegerAbs(N, DAG))
+      return RV;
+
+  if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
+    return FPLogic;
+
+  if (isFNEG(N))
+    return combineFneg(N, DAG, Subtarget);
+  return SDValue();
 }
 
 /// Do target-specific dag combines on X86ISD::FOR and X86ISD::FXOR nodes.
@@ -30907,18 +30917,20 @@ static SDValue combineFMA(SDNode *N, Sel
   SDValue B = N->getOperand(1);
   SDValue C = N->getOperand(2);
 
-  bool NegA = isFNEG(A.getNode());
-  bool NegB = isFNEG(B.getNode());
-  bool NegC = isFNEG(C.getNode());
+  auto invertIfNegative = [](SDValue &V) {
+    if (SDValue NegVal = isFNEG(V.getNode())) {
+      V = NegVal;
+      return true;
+    }
+    return false;
+  };
+
+  bool NegA = invertIfNegative(A);
+  bool NegB = invertIfNegative(B);
+  bool NegC = invertIfNegative(C);
 
   // Negative multiplication when NegA xor NegB
   bool NegMul = (NegA != NegB);
-  if (NegA)
-    A = A.getOperand(0);
-  if (NegB)
-    B = B.getOperand(0);
-  if (NegC)
-    C = C.getOperand(0);
 
   unsigned NewOpcode;
   if (!NegMul)

Modified: llvm/trunk/test/CodeGen/X86/fma-fneg-combine.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/fma-fneg-combine.ll?rev=280785&r1=280784&r2=280785&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/fma-fneg-combine.ll (original)
+++ llvm/trunk/test/CodeGen/X86/fma-fneg-combine.ll Wed Sep  7 01:54:28 2016
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512bw -mattr=+avx512vl -mattr=+avx512dq  | FileCheck %s
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512bw -mattr=+avx512vl -mattr=+avx512dq  | FileCheck %s  --check-prefix=CHECK --check-prefix=SKX
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512f -mattr=+fma | FileCheck %s --check-prefix=CHECK --check-prefix=KNL
 
 ; This test checks combinations of FNEG and FMA intrinsics on AVX-512 target
 ; PR28892
@@ -88,11 +89,18 @@ entry:
 }
 
 define <8 x float> @test8(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
-; CHECK-LABEL: test8:
-; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to8}, %ymm2, %ymm2
-; CHECK-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
-; CHECK-NEXT:    retq
+; SKX-LABEL: test8:
+; SKX:       # BB#0: # %entry
+; SKX-NEXT:    vxorps {{.*}}(%rip){1to8}, %ymm2, %ymm2
+; SKX-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
+; SKX-NEXT:    retq
+;
+; KNL-LABEL: test8:
+; KNL:       # BB#0: # %entry
+; KNL-NEXT:    vbroadcastss {{.*}}(%rip), %ymm3
+; KNL-NEXT:    vxorps %ymm3, %ymm2, %ymm2
+; KNL-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
+; KNL-NEXT:    retq
 entry:
   %sub.c = fsub <8 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c
   %0 = tail call <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %sub.c) #2
@@ -115,22 +123,9 @@ entry:
 
 declare <8 x double> @llvm.x86.avx512.mask.vfmadd.pd.512(<8 x double> %a, <8 x double> %b, <8 x double> %c, i8, i32)
 
-define <4 x double> @test10(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
+define <2 x double> @test10(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
 ; CHECK-LABEL: test10:
 ; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vfnmsub213pd %ymm2, %ymm1, %ymm0
-; CHECK-NEXT:    retq
-entry:
-  %0 = tail call <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8 -1) #2
-  %sub.i = fsub <4 x double> <double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00>, %0
-  ret <4 x double> %sub.i
-}
-
-declare <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8)
-
-define <2 x double> @test11(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
-; CHECK-LABEL: test11:
-; CHECK:       # BB#0: # %entry
 ; CHECK-NEXT:    vfnmsub213sd %xmm2, %xmm0, %xmm1
 ; CHECK-NEXT:    vmovaps %xmm1, %xmm0
 ; CHECK-NEXT:    retq




More information about the llvm-commits mailing list