[llvm] 728cf6d - Revert "[DAGCombine] Remove the getNegatibleCost to avoid the out of sync with getNegatedExpression"

Sam McCall via llvm-commits llvm-commits at lists.llvm.org
Mon May 11 07:44:44 PDT 2020


Author: Sam McCall
Date: 2020-05-11T16:44:01+02:00
New Revision: 728cf6d86b4f20144ac10517afb0cb978beac124

URL: https://github.com/llvm/llvm-project/commit/728cf6d86b4f20144ac10517afb0cb978beac124
DIFF: https://github.com/llvm/llvm-project/commit/728cf6d86b4f20144ac10517afb0cb978beac124.diff

LOG: Revert "[DAGCombine] Remove the getNegatibleCost to avoid the out of sync with getNegatedExpression"

This reverts commit 3c44c441db0f8d7e210806b5b221cd9ed66f2d7b.

Causes infinite loops on some inputs; see https://reviews.llvm.org/D77319 for a reproducer.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
    llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/lib/Target/X86/X86ISelLowering.h
    llvm/test/CodeGen/X86/neg_fp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index e647537ff245..811aba2b443c 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3535,24 +3535,37 @@ class TargetLowering : public TargetLoweringBase {
     llvm_unreachable("Not Implemented");
   }
 
+  /// Returns whether computing the negated form of the specified expression is
+  /// more expensive, the same cost or cheaper.
+  virtual NegatibleCost getNegatibleCost(SDValue Op, SelectionDAG &DAG,
+                                         bool LegalOperations, bool ForCodeSize,
+                                         unsigned Depth = 0) const;
+
+  /// If getNegatibleCost returns Neutral/Cheaper, return the newly negated
+  /// expression.
+  virtual SDValue negateExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps,
+                                   bool OptForSize, unsigned Depth = 0) const;
+
   /// Return the newly negated expression if the cost is not expensive and
   /// set the cost in \p Cost to indicate that if it is cheaper or neutral to
   /// do the negation.
-  virtual SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,
-                                       bool LegalOps, bool OptForSize,
-                                       NegatibleCost &Cost,
-                                       unsigned Depth = 0) const;
+  SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps,
+                               bool OptForSize, NegatibleCost &Cost,
+                               unsigned Depth = 0) const {
+    Cost = getNegatibleCost(Op, DAG, LegalOps, OptForSize, Depth);
+    if (Cost != NegatibleCost::Expensive)
+      return negateExpression(Op, DAG, LegalOps, OptForSize, Depth);
+    return SDValue();
+  }
 
   /// This is the helper function to return the newly negated expression only
   /// when the cost is cheaper.
   SDValue getCheaperNegatedExpression(SDValue Op, SelectionDAG &DAG,
                                       bool LegalOps, bool OptForSize,
                                       unsigned Depth = 0) const {
-    NegatibleCost Cost = NegatibleCost::Expensive;
-    SDValue Neg =
-        getNegatedExpression(Op, DAG, LegalOps, OptForSize, Cost, Depth);
-    if (Neg && Cost == NegatibleCost::Cheaper)
-      return Neg;
+    if (getNegatibleCost(Op, DAG, LegalOps, OptForSize, Depth) ==
+        NegatibleCost::Cheaper)
+      return negateExpression(Op, DAG, LegalOps, OptForSize, Depth);
     return SDValue();
   }
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 54f99b229043..e0f7040165d5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5583,79 +5583,165 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
   return false;
 }
 
-SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
-                                             bool LegalOps, bool OptForSize,
-                                             NegatibleCost &Cost,
-                                             unsigned Depth) const {
+TargetLowering::NegatibleCost
+TargetLowering::getNegatibleCost(SDValue Op, SelectionDAG &DAG,
+                                 bool LegalOperations, bool ForCodeSize,
+                                 unsigned Depth) const {
   // fneg is removable even if it has multiple uses.
-  if (Op.getOpcode() == ISD::FNEG) {
-    Cost = NegatibleCost::Cheaper;
-    return Op.getOperand(0);
-  }
+  if (Op.getOpcode() == ISD::FNEG)
+    return NegatibleCost::Cheaper;
 
-  // Don't recurse exponentially.
-  if (Depth > SelectionDAG::MaxRecursionDepth)
-    return SDValue();
-
-  // Pre-increment recursion depth for use in recursive calls.
-  ++Depth;
+  // Don't allow anything with multiple uses unless we know it is free.
+  EVT VT = Op.getValueType();
   const SDNodeFlags Flags = Op->getFlags();
   const TargetOptions &Options = DAG.getTarget().Options;
-  EVT VT = Op.getValueType();
-  unsigned Opcode = Op.getOpcode();
-
-  // Don't allow anything with multiple uses unless we know it is free.
-  if (!Op.hasOneUse() && Opcode != ISD::ConstantFP) {
-    bool IsFreeExtend = Opcode == ISD::FP_EXTEND &&
+  if (!Op.hasOneUse()) {
+    bool IsFreeExtend = Op.getOpcode() == ISD::FP_EXTEND &&
                         isFPExtFree(VT, Op.getOperand(0).getValueType());
-    if (!IsFreeExtend)
-      return SDValue();
+
+    // If we already have the use of the negated floating constant, it is free
+    // to negate it even it has multiple uses.
+    bool IsFreeConstant =
+        Op.getOpcode() == ISD::ConstantFP &&
+        !negateExpression(Op, DAG, LegalOperations, ForCodeSize).use_empty();
+
+    if (!IsFreeExtend && !IsFreeConstant)
+      return NegatibleCost::Expensive;
   }
 
-  SDLoc DL(Op);
+  // Don't recurse exponentially.
+  if (Depth > SelectionDAG::MaxRecursionDepth)
+    return NegatibleCost::Expensive;
 
-  switch (Opcode) {
+  switch (Op.getOpcode()) {
   case ISD::ConstantFP: {
+    if (!LegalOperations)
+      return NegatibleCost::Neutral;
+
     // Don't invert constant FP values after legalization unless the target says
     // the negated constant is legal.
-    bool IsOpLegal =
-        isOperationLegal(ISD::ConstantFP, VT) ||
+    if (isOperationLegal(ISD::ConstantFP, VT) ||
         isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT,
-                     OptForSize);
-
-    if (LegalOps && !IsOpLegal)
-      break;
-
-    APFloat V = cast<ConstantFPSDNode>(Op)->getValueAPF();
-    V.changeSign();
-    SDValue CFP = DAG.getConstantFP(V, DL, VT);
-
-    // If we already have the use of the negated floating constant, it is free
-    // to negate it even it has multiple uses.
-    if (!Op.hasOneUse() && CFP.use_empty())
-      break;
-    Cost = NegatibleCost::Neutral;
-    return CFP;
+                     ForCodeSize))
+      return NegatibleCost::Neutral;
+    break;
   }
   case ISD::BUILD_VECTOR: {
     // Only permit BUILD_VECTOR of constants.
     if (llvm::any_of(Op->op_values(), [&](SDValue N) {
           return !N.isUndef() && !isa<ConstantFPSDNode>(N);
         }))
-      break;
-
-    bool IsOpLegal =
-        (isOperationLegal(ISD::ConstantFP, VT) &&
-         isOperationLegal(ISD::BUILD_VECTOR, VT)) ||
-        llvm::all_of(Op->op_values(), [&](SDValue N) {
+      return NegatibleCost::Expensive;
+    if (!LegalOperations)
+      return NegatibleCost::Neutral;
+    if (isOperationLegal(ISD::ConstantFP, VT) &&
+        isOperationLegal(ISD::BUILD_VECTOR, VT))
+      return NegatibleCost::Neutral;
+    if (llvm::all_of(Op->op_values(), [&](SDValue N) {
           return N.isUndef() ||
                  isFPImmLegal(neg(cast<ConstantFPSDNode>(N)->getValueAPF()), VT,
-                              OptForSize);
-        });
+                              ForCodeSize);
+        }))
+      return NegatibleCost::Neutral;
+    break;
+  }
+  case ISD::FADD: {
+    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
+      return NegatibleCost::Expensive;
 
-    if (LegalOps && !IsOpLegal)
-      break;
+    // After operation legalization, it might not be legal to create new FSUBs.
+    if (LegalOperations && !isOperationLegalOrCustom(ISD::FSUB, VT))
+      return NegatibleCost::Expensive;
+
+    // fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
+    NegatibleCost V0 = getNegatibleCost(Op.getOperand(0), DAG, LegalOperations,
+                                        ForCodeSize, Depth + 1);
+    if (V0 != NegatibleCost::Expensive)
+      return V0;
+    // fold (fneg (fadd A, B)) -> (fsub (fneg B), A)
+    return getNegatibleCost(Op.getOperand(1), DAG, LegalOperations, ForCodeSize,
+                            Depth + 1);
+  }
+  case ISD::FSUB:
+    // We can't turn -(A-B) into B-A when we honor signed zeros.
+    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
+      return NegatibleCost::Expensive;
+
+    // fold (fneg (fsub A, B)) -> (fsub B, A)
+    return NegatibleCost::Neutral;
+  case ISD::FMUL:
+  case ISD::FDIV: {
+    // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y))
+    NegatibleCost V0 = getNegatibleCost(Op.getOperand(0), DAG, LegalOperations,
+                                        ForCodeSize, Depth + 1);
+    if (V0 != NegatibleCost::Expensive)
+      return V0;
+
+    // Ignore X * 2.0 because that is expected to be canonicalized to X + X.
+    if (auto *C = isConstOrConstSplatFP(Op.getOperand(1)))
+      if (C->isExactlyValue(2.0) && Op.getOpcode() == ISD::FMUL)
+        return NegatibleCost::Expensive;
+
+    return getNegatibleCost(Op.getOperand(1), DAG, LegalOperations, ForCodeSize,
+                            Depth + 1);
+  }
+  case ISD::FMA:
+  case ISD::FMAD: {
+    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
+      return NegatibleCost::Expensive;
+
+    // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
+    // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
+    NegatibleCost V2 = getNegatibleCost(Op.getOperand(2), DAG, LegalOperations,
+                                        ForCodeSize, Depth + 1);
+    if (NegatibleCost::Expensive == V2)
+      return NegatibleCost::Expensive;
+
+    // One of Op0/Op1 must be cheaply negatible, then select the cheapest.
+    NegatibleCost V0 = getNegatibleCost(Op.getOperand(0), DAG, LegalOperations,
+                                        ForCodeSize, Depth + 1);
+    NegatibleCost V1 = getNegatibleCost(Op.getOperand(1), DAG, LegalOperations,
+                                        ForCodeSize, Depth + 1);
+    NegatibleCost V01 = std::min(V0, V1);
+    if (V01 == NegatibleCost::Expensive)
+      return NegatibleCost::Expensive;
+    return std::min(V01, V2);
+  }
 
+  case ISD::FP_EXTEND:
+  case ISD::FP_ROUND:
+  case ISD::FSIN:
+    return getNegatibleCost(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
+                            Depth + 1);
+  }
+
+  return NegatibleCost::Expensive;
+}
+
+SDValue TargetLowering::negateExpression(SDValue Op, SelectionDAG &DAG,
+                                         bool LegalOps, bool OptForSize,
+                                         unsigned Depth) const {
+  // fneg is removable even if it has multiple uses.
+  if (Op.getOpcode() == ISD::FNEG)
+    return Op.getOperand(0);
+
+  assert(Depth <= SelectionDAG::MaxRecursionDepth &&
+         "negateExpression doesn't match getNegatibleCost");
+
+  // Pre-increment recursion depth for use in recursive calls.
+  ++Depth;
+  const SDNodeFlags Flags = Op->getFlags();
+  EVT VT = Op.getValueType();
+  unsigned Opcode = Op.getOpcode();
+  SDLoc DL(Op);
+
+  switch (Opcode) {
+  case ISD::ConstantFP: {
+    APFloat V = cast<ConstantFPSDNode>(Op)->getValueAPF();
+    V.changeSign();
+    return DAG.getConstantFP(V, DL, VT);
+  }
+  case ISD::BUILD_VECTOR: {
     SmallVector<SDValue, 4> Ops;
     for (SDValue C : Op->op_values()) {
       if (C.isUndef()) {
@@ -5666,138 +5752,85 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
       V.changeSign();
       Ops.push_back(DAG.getConstantFP(V, DL, C.getValueType()));
     }
-    Cost = NegatibleCost::Neutral;
     return DAG.getBuildVector(VT, DL, Ops);
   }
   case ISD::FADD: {
-    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
-      break;
-
-    // After operation legalization, it might not be legal to create new FSUBs.
-    if (LegalOps && !isOperationLegalOrCustom(ISD::FSUB, VT))
-      break;
     SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
+    assert((DAG.getTarget().Options.NoSignedZerosFPMath ||
+            Flags.hasNoSignedZeros()) &&
+           "Expected NSZ fp-flag");
 
     // fold (fneg (fadd X, Y)) -> (fsub (fneg X), Y)
-    NegatibleCost CostX = NegatibleCost::Expensive;
-    SDValue NegX =
-        getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
-    // fold (fneg (fadd X, Y)) -> (fsub (fneg Y), X)
-    NegatibleCost CostY = NegatibleCost::Expensive;
-    SDValue NegY =
-        getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
-
-    // Negate the X if its cost is less or equal than Y.
-    if (NegX && (CostX <= CostY)) {
-      Cost = CostX;
-      return DAG.getNode(ISD::FSUB, DL, VT, NegX, Y, Flags);
-    }
+    NegatibleCost CostX = getNegatibleCost(X, DAG, LegalOps, OptForSize, Depth);
+    if (CostX != NegatibleCost::Expensive)
+      return DAG.getNode(ISD::FSUB, DL, VT,
+                         negateExpression(X, DAG, LegalOps, OptForSize, Depth),
+                         Y, Flags);
 
-    // Negate the Y if it is not expensive.
-    if (NegY) {
-      Cost = CostY;
-      return DAG.getNode(ISD::FSUB, DL, VT, NegY, X, Flags);
-    }
-    break;
+    // fold (fneg (fadd X, Y)) -> (fsub (fneg Y), X)
+    return DAG.getNode(ISD::FSUB, DL, VT,
+                       negateExpression(Y, DAG, LegalOps, OptForSize, Depth), X,
+                       Flags);
   }
   case ISD::FSUB: {
-    // We can't turn -(A-B) into B-A when we honor signed zeros.
-    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
-      break;
-
     SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
     // fold (fneg (fsub 0, Y)) -> Y
     if (ConstantFPSDNode *C = isConstOrConstSplatFP(X, /*AllowUndefs*/ true))
-      if (C->isZero()) {
-        Cost = NegatibleCost::Cheaper;
+      if (C->isZero())
         return Y;
-      }
 
     // fold (fneg (fsub X, Y)) -> (fsub Y, X)
-    Cost = NegatibleCost::Neutral;
     return DAG.getNode(ISD::FSUB, DL, VT, Y, X, Flags);
   }
   case ISD::FMUL:
   case ISD::FDIV: {
     SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
-
     // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
-    NegatibleCost CostX = NegatibleCost::Expensive;
-    SDValue NegX =
-        getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
-    // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y))
-    NegatibleCost CostY = NegatibleCost::Expensive;
-    SDValue NegY =
-        getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
-
-    // Negate the X if its cost is less or equal than Y.
-    if (NegX && (CostX <= CostY)) {
-      Cost = CostX;
-      return DAG.getNode(Opcode, DL, VT, NegX, Y, Flags);
-    }
-
-    // Ignore X * 2.0 because that is expected to be canonicalized to X + X.
-    if (auto *C = isConstOrConstSplatFP(Op.getOperand(1)))
-      if (C->isExactlyValue(2.0) && Op.getOpcode() == ISD::FMUL)
-        break;
+    NegatibleCost CostX = getNegatibleCost(X, DAG, LegalOps, OptForSize, Depth);
+    if (CostX != NegatibleCost::Expensive)
+      return DAG.getNode(Opcode, DL, VT,
+                         negateExpression(X, DAG, LegalOps, OptForSize, Depth),
+                         Y, Flags);
 
-    // Negate the Y if it is not expensive.
-    if (NegY) {
-      Cost = CostY;
-      return DAG.getNode(Opcode, DL, VT, X, NegY, Flags);
-    }
-    break;
+    // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y))
+    return DAG.getNode(Opcode, DL, VT, X,
+                       negateExpression(Y, DAG, LegalOps, OptForSize, Depth),
+                       Flags);
   }
   case ISD::FMA:
   case ISD::FMAD: {
-    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
-      break;
+    assert((DAG.getTarget().Options.NoSignedZerosFPMath ||
+            Flags.hasNoSignedZeros()) &&
+           "Expected NSZ fp-flag");
 
     SDValue X = Op.getOperand(0), Y = Op.getOperand(1), Z = Op.getOperand(2);
-    NegatibleCost CostZ = NegatibleCost::Expensive;
-    SDValue NegZ =
-        getNegatedExpression(Z, DAG, LegalOps, OptForSize, CostZ, Depth);
-    // Give up if fail to negate the Z.
-    if (!NegZ)
-      break;
-
-    // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
-    NegatibleCost CostX = NegatibleCost::Expensive;
-    SDValue NegX =
-        getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
-    // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
-    NegatibleCost CostY = NegatibleCost::Expensive;
-    SDValue NegY =
-        getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
-
-    // Negate the X if its cost is less or equal than Y.
-    if (NegX && (CostX <= CostY)) {
-      Cost = std::min(CostX, CostZ);
+    SDValue NegZ = negateExpression(Z, DAG, LegalOps, OptForSize, Depth);
+    NegatibleCost CostX = getNegatibleCost(X, DAG, LegalOps, OptForSize, Depth);
+    NegatibleCost CostY = getNegatibleCost(Y, DAG, LegalOps, OptForSize, Depth);
+    if (CostX <= CostY) {
+      // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
+      SDValue NegX = negateExpression(X, DAG, LegalOps, OptForSize, Depth);
       return DAG.getNode(Opcode, DL, VT, NegX, Y, NegZ, Flags);
     }
 
-    // Negate the Y if it is not expensive.
-    if (NegY) {
-      Cost = std::min(CostY, CostZ);
-      return DAG.getNode(Opcode, DL, VT, X, NegY, NegZ, Flags);
-    }
-    break;
+    // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
+    SDValue NegY = negateExpression(Y, DAG, LegalOps, OptForSize, Depth);
+    return DAG.getNode(Opcode, DL, VT, X, NegY, NegZ, Flags);
   }
 
   case ISD::FP_EXTEND:
   case ISD::FSIN:
-    if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps,
-                                            OptForSize, Cost, Depth))
-      return DAG.getNode(Opcode, DL, VT, NegV);
-    break;
+    return DAG.getNode(
+        Opcode, DL, VT,
+        negateExpression(Op.getOperand(0), DAG, LegalOps, OptForSize, Depth));
   case ISD::FP_ROUND:
-    if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps,
-                                            OptForSize, Cost, Depth))
-      return DAG.getNode(ISD::FP_ROUND, DL, VT, NegV, Op.getOperand(1));
-    break;
+    return DAG.getNode(
+        ISD::FP_ROUND, DL, VT,
+        negateExpression(Op.getOperand(0), DAG, LegalOps, OptForSize, Depth),
+        Op.getOperand(1));
   }
 
-  return SDValue();
+  llvm_unreachable("Unknown code");
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 2b1c0594b8d1..014bb5a1b4ee 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -756,24 +756,24 @@ bool AMDGPUTargetLowering::isSDNodeAlwaysUniform(const SDNode * N) const {
   }
 }
 
-SDValue AMDGPUTargetLowering::getNegatedExpression(
-    SDValue Op, SelectionDAG &DAG, bool LegalOperations, bool ForCodeSize,
-    NegatibleCost &Cost, unsigned Depth) const {
-
+TargetLowering::NegatibleCost
+AMDGPUTargetLowering::getNegatibleCost(SDValue Op, SelectionDAG &DAG,
+                                       bool LegalOperations, bool ForCodeSize,
+                                       unsigned Depth) const {
   switch (Op.getOpcode()) {
   case ISD::FMA:
   case ISD::FMAD: {
     // Negating a fma is not free if it has users without source mods.
     if (!allUsesHaveSourceMods(Op.getNode()))
-      return SDValue();
+      return NegatibleCost::Expensive;
     break;
   }
   default:
     break;
   }
 
-  return TargetLowering::getNegatedExpression(Op, DAG, LegalOperations,
-                                              ForCodeSize, Cost, Depth);
+  return TargetLowering::getNegatibleCost(Op, DAG, LegalOperations, ForCodeSize,
+                                          Depth);
 }
 
 //===---------------------------------------------------------------------===//

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index df4faf700f66..d81b447d1a19 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -170,10 +170,9 @@ class AMDGPUTargetLowering : public TargetLowering {
   bool isZExtFree(EVT Src, EVT Dest) const override;
   bool isZExtFree(SDValue Val, EVT VT2) const override;
 
-  SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,
-                               bool LegalOperations, bool ForCodeSize,
-                               NegatibleCost &Cost,
-                               unsigned Depth) const override;
+  NegatibleCost getNegatibleCost(SDValue Op, SelectionDAG &DAG,
+                                 bool LegalOperations, bool ForCodeSize,
+                                 unsigned Depth) const override;
 
   bool isNarrowingProfitable(EVT VT1, EVT VT2) const override;
 

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 846265e85d36..09a4f70da4fa 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43995,16 +43995,60 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
-SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
-                                                bool LegalOperations,
-                                                bool ForCodeSize,
-                                                NegatibleCost &Cost,
-                                                unsigned Depth) const {
+TargetLowering::NegatibleCost
+X86TargetLowering::getNegatibleCost(SDValue Op, SelectionDAG &DAG,
+                                    bool LegalOperations, bool ForCodeSize,
+                                    unsigned Depth) const {
   // fneg patterns are removable even if they have multiple uses.
-  if (SDValue Arg = isFNEG(DAG, Op.getNode(), Depth)) {
-    Cost = NegatibleCost::Cheaper;
-    return DAG.getBitcast(Op.getValueType(), Arg);
+  if (isFNEG(DAG, Op.getNode(), Depth))
+    return NegatibleCost::Cheaper;
+
+  // Don't recurse exponentially.
+  if (Depth > SelectionDAG::MaxRecursionDepth)
+    return NegatibleCost::Expensive;
+
+  EVT VT = Op.getValueType();
+  EVT SVT = VT.getScalarType();
+  switch (Op.getOpcode()) {
+  case ISD::FMA:
+  case X86ISD::FMSUB:
+  case X86ISD::FNMADD:
+  case X86ISD::FNMSUB:
+  case X86ISD::FMADD_RND:
+  case X86ISD::FMSUB_RND:
+  case X86ISD::FNMADD_RND:
+  case X86ISD::FNMSUB_RND: {
+    if (!Op.hasOneUse() || !Subtarget.hasAnyFMA() || !isTypeLegal(VT) ||
+        !(SVT == MVT::f32 || SVT == MVT::f64) ||
+        !isOperationLegal(ISD::FMA, VT))
+      break;
+
+    // This is always negatible for free but we might be able to remove some
+    // extra operand negations as well.
+    for (int i = 0; i != 3; ++i) {
+      NegatibleCost V = getNegatibleCost(Op.getOperand(i), DAG, LegalOperations,
+                                         ForCodeSize, Depth + 1);
+      if (V == NegatibleCost::Cheaper)
+        return V;
+    }
+    return NegatibleCost::Neutral;
   }
+  case X86ISD::FRCP:
+    return getNegatibleCost(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
+                            Depth + 1);
+  }
+
+  return TargetLowering::getNegatibleCost(Op, DAG, LegalOperations, ForCodeSize,
+                                          Depth);
+}
+
+SDValue X86TargetLowering::negateExpression(SDValue Op, SelectionDAG &DAG,
+                                            bool LegalOperations,
+                                            bool ForCodeSize,
+                                            unsigned Depth) const {
+  // fneg patterns are removable even if they have multiple uses.
+  if (SDValue Arg = isFNEG(DAG, Op.getNode(), Depth))
+    return DAG.getBitcast(Op.getValueType(), Arg);
 
   EVT VT = Op.getValueType();
   EVT SVT = VT.getScalarType();
@@ -44035,9 +44079,6 @@ SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
     bool NegC = !!NewOps[2];
     unsigned NewOpc = negateFMAOpcode(Opc, NegA != NegB, NegC, true);
 
-    Cost = (NegA || NegB || NegC) ? NegatibleCost::Cheaper
-                                  : NegatibleCost::Neutral;
-
     // Fill in the non-negated ops with the original values.
     for (int i = 0, e = Op.getNumOperands(); i != e; ++i)
       if (!NewOps[i])
@@ -44045,15 +44086,14 @@ SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
     return DAG.getNode(NewOpc, SDLoc(Op), VT, NewOps);
   }
   case X86ISD::FRCP:
-    if (SDValue NegOp0 =
-            getNegatedExpression(Op.getOperand(0), DAG, LegalOperations,
-                                 ForCodeSize, Cost, Depth + 1))
-      return DAG.getNode(Opc, SDLoc(Op), VT, NegOp0);
+    return DAG.getNode(Opc, SDLoc(Op), VT,
+                       negateExpression(Op.getOperand(0), DAG, LegalOperations,
+                                        ForCodeSize, Depth + 1));
     break;
   }
 
-  return TargetLowering::getNegatedExpression(Op, DAG, LegalOperations,
-                                              ForCodeSize, Cost, Depth);
+  return TargetLowering::negateExpression(Op, DAG, LegalOperations, ForCodeSize,
+                                          Depth);
 }
 
 static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG,

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index c24551c050ce..50daba821ae9 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -814,13 +814,17 @@ namespace llvm {
     /// and some i16 instructions are slow.
     bool IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const override;
 
-    /// Return the newly negated expression if the cost is not expensive and
-    /// set the cost in \p Cost to indicate that if it is cheaper or neutral to
-    /// do the negation.
-    SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,
-                                 bool LegalOperations, bool ForCodeSize,
-                                 NegatibleCost &Cost,
-                                 unsigned Depth) const override;
+    /// Returns whether computing the negated form of the specified expression
+    /// is more expensive, the same cost or cheaper.
+    NegatibleCost getNegatibleCost(SDValue Op, SelectionDAG &DAG,
+                                   bool LegalOperations, bool ForCodeSize,
+                                   unsigned Depth) const override;
+
+    /// If getNegatibleCost returns Neutral/Cheaper, return the newly negated
+    /// expression.
+    SDValue negateExpression(SDValue Op, SelectionDAG &DAG,
+                             bool LegalOperations, bool ForCodeSize,
+                             unsigned Depth) const override;
 
     MachineBasicBlock *
     EmitInstrWithCustomInserter(MachineInstr &MI,

diff  --git a/llvm/test/CodeGen/X86/neg_fp.ll b/llvm/test/CodeGen/X86/neg_fp.ll
index e07157c952f7..3c04aafcea4d 100644
--- a/llvm/test/CodeGen/X86/neg_fp.ll
+++ b/llvm/test/CodeGen/X86/neg_fp.ll
@@ -54,30 +54,3 @@ define double @negation_propagation(double* %arg, double %arg1, double %arg2) no
   %t18 = fadd double %t16, %t7
   ret double %t18
 }
-
-; This would crash because the negated expression for %sub4
-; creates a new use of %sub1 and that alters the negated cost
-
-define float @fdiv_extra_use_changes_cost(float %a0, float %a1, float %a2) nounwind {
-; CHECK-LABEL: fdiv_extra_use_changes_cost:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    pushl %eax
-; CHECK-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; CHECK-NEXT:    movss {{.*#+}} xmm1 = mem[0],zero,zero,zero
-; CHECK-NEXT:    subss {{[0-9]+}}(%esp), %xmm1
-; CHECK-NEXT:    movaps %xmm1, %xmm2
-; CHECK-NEXT:    mulss %xmm0, %xmm2
-; CHECK-NEXT:    subss %xmm1, %xmm0
-; CHECK-NEXT:    divss %xmm2, %xmm0
-; CHECK-NEXT:    movss %xmm0, (%esp)
-; CHECK-NEXT:    flds (%esp)
-; CHECK-NEXT:    popl %eax
-; CHECK-NEXT:    retl
-  %sub1 = fsub fast float %a0, %a1
-  %mul2 = fmul fast float %sub1, %a2
-  %neg = fneg fast float %a0
-  %add3 = fadd fast float %a1, %neg
-  %sub4 = fadd fast float %add3, %a2
-  %div5 = fdiv fast float %sub4, %mul2
-  ret float %div5
-}


        


More information about the llvm-commits mailing list