[llvm] 2b59e9f - [DAGCombine] Remove the getNegatibleCost to avoid the out of sync with getNegatedExpression

QingShan Zhang via llvm-commits llvm-commits at lists.llvm.org
Tue May 19 19:12:26 PDT 2020


Author: QingShan Zhang
Date: 2020-05-20T02:12:16Z
New Revision: 2b59e9f1bdd8d0f8a315f9ddba2cdc87a3e682fb

URL: https://github.com/llvm/llvm-project/commit/2b59e9f1bdd8d0f8a315f9ddba2cdc87a3e682fb
DIFF: https://github.com/llvm/llvm-project/commit/2b59e9f1bdd8d0f8a315f9ddba2cdc87a3e682fb.diff

LOG: [DAGCombine] Remove the getNegatibleCost to avoid the out of sync with getNegatedExpression

We have getNegatibleCost/getNegatedExpression to evaluate the cost of negating an expression and to perform the negation.
However, while negating the expression, the cost might change because we are modifying the DAG,
and we can then hit an assertion if we negate the wrong expression, as the previously computed cost is no longer trustworthy.

This patch removes getNegatibleCost so that it can no longer get out of sync with getNegatedExpression,
and checks the cost while negating the expression. It also reduces the duplicated code between
getNegatibleCost and getNegatedExpression, and fixes the crash for the test in D76638.

Reviewed By: RKSimon, spatel

Differential Revision: https://reviews.llvm.org/D77319

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
    llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/lib/Target/X86/X86ISelLowering.h
    llvm/test/CodeGen/X86/neg_fp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 58da3fa1c71b..2cab2572335c 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3545,37 +3545,27 @@ class TargetLowering : public TargetLoweringBase {
     llvm_unreachable("Not Implemented");
   }
 
-  /// Returns whether computing the negated form of the specified expression is
-  /// more expensive, the same cost or cheaper.
-  virtual NegatibleCost getNegatibleCost(SDValue Op, SelectionDAG &DAG,
-                                         bool LegalOperations, bool ForCodeSize,
-                                         unsigned Depth = 0) const;
-
-  /// If getNegatibleCost returns Neutral/Cheaper, return the newly negated
-  /// expression.
-  virtual SDValue negateExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps,
-                                   bool OptForSize, unsigned Depth = 0) const;
-
   /// Return the newly negated expression if the cost is not expensive and
   /// set the cost in \p Cost to indicate that if it is cheaper or neutral to
   /// do the negation.
-  SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps,
-                               bool OptForSize, NegatibleCost &Cost,
-                               unsigned Depth = 0) const {
-    Cost = getNegatibleCost(Op, DAG, LegalOps, OptForSize, Depth);
-    if (Cost != NegatibleCost::Expensive)
-      return negateExpression(Op, DAG, LegalOps, OptForSize, Depth);
-    return SDValue();
-  }
+  virtual SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,
+                                       bool LegalOps, bool OptForSize,
+                                       NegatibleCost &Cost,
+                                       unsigned Depth = 0) const;
 
   /// This is the helper function to return the newly negated expression only
   /// when the cost is cheaper.
   SDValue getCheaperNegatedExpression(SDValue Op, SelectionDAG &DAG,
                                       bool LegalOps, bool OptForSize,
                                       unsigned Depth = 0) const {
-    if (getNegatibleCost(Op, DAG, LegalOps, OptForSize, Depth) ==
-        NegatibleCost::Cheaper)
-      return negateExpression(Op, DAG, LegalOps, OptForSize, Depth);
+    NegatibleCost Cost = NegatibleCost::Expensive;
+    SDValue Neg =
+        getNegatedExpression(Op, DAG, LegalOps, OptForSize, Cost, Depth);
+    if (Neg && Cost == NegatibleCost::Cheaper)
+      return Neg;
+    // Remove the new created node to avoid the side effect to the DAG.
+    if (Neg && Neg.getNode()->use_empty())
+      DAG.RemoveDeadNode(Neg.getNode());
     return SDValue();
   }
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 00a7a87a5530..a8890788ccbf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5562,165 +5562,79 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
   return false;
 }
 
-TargetLowering::NegatibleCost
-TargetLowering::getNegatibleCost(SDValue Op, SelectionDAG &DAG,
-                                 bool LegalOperations, bool ForCodeSize,
-                                 unsigned Depth) const {
+SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
+                                             bool LegalOps, bool OptForSize,
+                                             NegatibleCost &Cost,
+                                             unsigned Depth) const {
   // fneg is removable even if it has multiple uses.
-  if (Op.getOpcode() == ISD::FNEG)
-    return NegatibleCost::Cheaper;
+  if (Op.getOpcode() == ISD::FNEG) {
+    Cost = NegatibleCost::Cheaper;
+    return Op.getOperand(0);
+  }
 
-  // Don't allow anything with multiple uses unless we know it is free.
-  EVT VT = Op.getValueType();
+  // Don't recurse exponentially.
+  if (Depth > SelectionDAG::MaxRecursionDepth)
+    return SDValue();
+
+  // Pre-increment recursion depth for use in recursive calls.
+  ++Depth;
   const SDNodeFlags Flags = Op->getFlags();
   const TargetOptions &Options = DAG.getTarget().Options;
-  if (!Op.hasOneUse()) {
-    bool IsFreeExtend = Op.getOpcode() == ISD::FP_EXTEND &&
-                        isFPExtFree(VT, Op.getOperand(0).getValueType());
-
-    // If we already have the use of the negated floating constant, it is free
-    // to negate it even it has multiple uses.
-    bool IsFreeConstant =
-        Op.getOpcode() == ISD::ConstantFP &&
-        !negateExpression(Op, DAG, LegalOperations, ForCodeSize).use_empty();
+  EVT VT = Op.getValueType();
+  unsigned Opcode = Op.getOpcode();
 
-    if (!IsFreeExtend && !IsFreeConstant)
-      return NegatibleCost::Expensive;
+  // Don't allow anything with multiple uses unless we know it is free.
+  if (!Op.hasOneUse() && Opcode != ISD::ConstantFP) {
+    bool IsFreeExtend = Opcode == ISD::FP_EXTEND &&
+                        isFPExtFree(VT, Op.getOperand(0).getValueType());
+    if (!IsFreeExtend)
+      return SDValue();
   }
 
-  // Don't recurse exponentially.
-  if (Depth > SelectionDAG::MaxRecursionDepth)
-    return NegatibleCost::Expensive;
+  SDLoc DL(Op);
 
-  switch (Op.getOpcode()) {
+  switch (Opcode) {
   case ISD::ConstantFP: {
-    if (!LegalOperations)
-      return NegatibleCost::Neutral;
-
     // Don't invert constant FP values after legalization unless the target says
     // the negated constant is legal.
-    if (isOperationLegal(ISD::ConstantFP, VT) ||
+    bool IsOpLegal =
+        isOperationLegal(ISD::ConstantFP, VT) ||
         isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT,
-                     ForCodeSize))
-      return NegatibleCost::Neutral;
-    break;
+                     OptForSize);
+
+    if (LegalOps && !IsOpLegal)
+      break;
+
+    APFloat V = cast<ConstantFPSDNode>(Op)->getValueAPF();
+    V.changeSign();
+    SDValue CFP = DAG.getConstantFP(V, DL, VT);
+
+    // If we already have the use of the negated floating constant, it is free
+    // to negate it even it has multiple uses.
+    if (!Op.hasOneUse() && CFP.use_empty())
+      break;
+    Cost = NegatibleCost::Neutral;
+    return CFP;
   }
   case ISD::BUILD_VECTOR: {
     // Only permit BUILD_VECTOR of constants.
     if (llvm::any_of(Op->op_values(), [&](SDValue N) {
           return !N.isUndef() && !isa<ConstantFPSDNode>(N);
         }))
-      return NegatibleCost::Expensive;
-    if (!LegalOperations)
-      return NegatibleCost::Neutral;
-    if (isOperationLegal(ISD::ConstantFP, VT) &&
-        isOperationLegal(ISD::BUILD_VECTOR, VT))
-      return NegatibleCost::Neutral;
-    if (llvm::all_of(Op->op_values(), [&](SDValue N) {
+      break;
+
+    bool IsOpLegal =
+        (isOperationLegal(ISD::ConstantFP, VT) &&
+         isOperationLegal(ISD::BUILD_VECTOR, VT)) ||
+        llvm::all_of(Op->op_values(), [&](SDValue N) {
           return N.isUndef() ||
                  isFPImmLegal(neg(cast<ConstantFPSDNode>(N)->getValueAPF()), VT,
-                              ForCodeSize);
-        }))
-      return NegatibleCost::Neutral;
-    break;
-  }
-  case ISD::FADD: {
-    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
-      return NegatibleCost::Expensive;
-
-    // After operation legalization, it might not be legal to create new FSUBs.
-    if (LegalOperations && !isOperationLegalOrCustom(ISD::FSUB, VT))
-      return NegatibleCost::Expensive;
-
-    // fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
-    NegatibleCost V0 = getNegatibleCost(Op.getOperand(0), DAG, LegalOperations,
-                                        ForCodeSize, Depth + 1);
-    if (V0 != NegatibleCost::Expensive)
-      return V0;
-    // fold (fneg (fadd A, B)) -> (fsub (fneg B), A)
-    return getNegatibleCost(Op.getOperand(1), DAG, LegalOperations, ForCodeSize,
-                            Depth + 1);
-  }
-  case ISD::FSUB:
-    // We can't turn -(A-B) into B-A when we honor signed zeros.
-    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
-      return NegatibleCost::Expensive;
-
-    // fold (fneg (fsub A, B)) -> (fsub B, A)
-    return NegatibleCost::Neutral;
-  case ISD::FMUL:
-  case ISD::FDIV: {
-    // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y))
-    NegatibleCost V0 = getNegatibleCost(Op.getOperand(0), DAG, LegalOperations,
-                                        ForCodeSize, Depth + 1);
-    if (V0 != NegatibleCost::Expensive)
-      return V0;
-
-    // Ignore X * 2.0 because that is expected to be canonicalized to X + X.
-    if (auto *C = isConstOrConstSplatFP(Op.getOperand(1)))
-      if (C->isExactlyValue(2.0) && Op.getOpcode() == ISD::FMUL)
-        return NegatibleCost::Expensive;
-
-    return getNegatibleCost(Op.getOperand(1), DAG, LegalOperations, ForCodeSize,
-                            Depth + 1);
-  }
-  case ISD::FMA:
-  case ISD::FMAD: {
-    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
-      return NegatibleCost::Expensive;
-
-    // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
-    // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
-    NegatibleCost V2 = getNegatibleCost(Op.getOperand(2), DAG, LegalOperations,
-                                        ForCodeSize, Depth + 1);
-    if (NegatibleCost::Expensive == V2)
-      return NegatibleCost::Expensive;
-
-    // One of Op0/Op1 must be cheaply negatible, then select the cheapest.
-    NegatibleCost V0 = getNegatibleCost(Op.getOperand(0), DAG, LegalOperations,
-                                        ForCodeSize, Depth + 1);
-    NegatibleCost V1 = getNegatibleCost(Op.getOperand(1), DAG, LegalOperations,
-                                        ForCodeSize, Depth + 1);
-    NegatibleCost V01 = std::min(V0, V1);
-    if (V01 == NegatibleCost::Expensive)
-      return NegatibleCost::Expensive;
-    return std::min(V01, V2);
-  }
+                              OptForSize);
+        });
 
-  case ISD::FP_EXTEND:
-  case ISD::FP_ROUND:
-  case ISD::FSIN:
-    return getNegatibleCost(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
-                            Depth + 1);
-  }
-
-  return NegatibleCost::Expensive;
-}
-
-SDValue TargetLowering::negateExpression(SDValue Op, SelectionDAG &DAG,
-                                         bool LegalOps, bool OptForSize,
-                                         unsigned Depth) const {
-  // fneg is removable even if it has multiple uses.
-  if (Op.getOpcode() == ISD::FNEG)
-    return Op.getOperand(0);
-
-  assert(Depth <= SelectionDAG::MaxRecursionDepth &&
-         "negateExpression doesn't match getNegatibleCost");
-
-  // Pre-increment recursion depth for use in recursive calls.
-  ++Depth;
-  const SDNodeFlags Flags = Op->getFlags();
-  EVT VT = Op.getValueType();
-  unsigned Opcode = Op.getOpcode();
-  SDLoc DL(Op);
+    if (LegalOps && !IsOpLegal)
+      break;
 
-  switch (Opcode) {
-  case ISD::ConstantFP: {
-    APFloat V = cast<ConstantFPSDNode>(Op)->getValueAPF();
-    V.changeSign();
-    return DAG.getConstantFP(V, DL, VT);
-  }
-  case ISD::BUILD_VECTOR: {
     SmallVector<SDValue, 4> Ops;
     for (SDValue C : Op->op_values()) {
       if (C.isUndef()) {
@@ -5731,85 +5645,138 @@ SDValue TargetLowering::negateExpression(SDValue Op, SelectionDAG &DAG,
       V.changeSign();
       Ops.push_back(DAG.getConstantFP(V, DL, C.getValueType()));
     }
+    Cost = NegatibleCost::Neutral;
     return DAG.getBuildVector(VT, DL, Ops);
   }
   case ISD::FADD: {
+    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
+      break;
+
+    // After operation legalization, it might not be legal to create new FSUBs.
+    if (LegalOps && !isOperationLegalOrCustom(ISD::FSUB, VT))
+      break;
     SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
-    assert((DAG.getTarget().Options.NoSignedZerosFPMath ||
-            Flags.hasNoSignedZeros()) &&
-           "Expected NSZ fp-flag");
 
     // fold (fneg (fadd X, Y)) -> (fsub (fneg X), Y)
-    NegatibleCost CostX = getNegatibleCost(X, DAG, LegalOps, OptForSize, Depth);
-    if (CostX != NegatibleCost::Expensive)
-      return DAG.getNode(ISD::FSUB, DL, VT,
-                         negateExpression(X, DAG, LegalOps, OptForSize, Depth),
-                         Y, Flags);
-
+    NegatibleCost CostX = NegatibleCost::Expensive;
+    SDValue NegX =
+        getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
     // fold (fneg (fadd X, Y)) -> (fsub (fneg Y), X)
-    return DAG.getNode(ISD::FSUB, DL, VT,
-                       negateExpression(Y, DAG, LegalOps, OptForSize, Depth), X,
-                       Flags);
+    NegatibleCost CostY = NegatibleCost::Expensive;
+    SDValue NegY =
+        getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
+
+    // Negate the X if its cost is less or equal than Y.
+    if (NegX && (CostX <= CostY)) {
+      Cost = CostX;
+      return DAG.getNode(ISD::FSUB, DL, VT, NegX, Y, Flags);
+    }
+
+    // Negate the Y if it is not expensive.
+    if (NegY) {
+      Cost = CostY;
+      return DAG.getNode(ISD::FSUB, DL, VT, NegY, X, Flags);
+    }
+    break;
   }
   case ISD::FSUB: {
+    // We can't turn -(A-B) into B-A when we honor signed zeros.
+    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
+      break;
+
     SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
     // fold (fneg (fsub 0, Y)) -> Y
     if (ConstantFPSDNode *C = isConstOrConstSplatFP(X, /*AllowUndefs*/ true))
-      if (C->isZero())
+      if (C->isZero()) {
+        Cost = NegatibleCost::Cheaper;
         return Y;
+      }
 
     // fold (fneg (fsub X, Y)) -> (fsub Y, X)
+    Cost = NegatibleCost::Neutral;
     return DAG.getNode(ISD::FSUB, DL, VT, Y, X, Flags);
   }
   case ISD::FMUL:
   case ISD::FDIV: {
     SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
-    // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
-    NegatibleCost CostX = getNegatibleCost(X, DAG, LegalOps, OptForSize, Depth);
-    if (CostX != NegatibleCost::Expensive)
-      return DAG.getNode(Opcode, DL, VT,
-                         negateExpression(X, DAG, LegalOps, OptForSize, Depth),
-                         Y, Flags);
 
+    // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
+    NegatibleCost CostX = NegatibleCost::Expensive;
+    SDValue NegX =
+        getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
     // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y))
-    return DAG.getNode(Opcode, DL, VT, X,
-                       negateExpression(Y, DAG, LegalOps, OptForSize, Depth),
-                       Flags);
+    NegatibleCost CostY = NegatibleCost::Expensive;
+    SDValue NegY =
+        getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
+
+    // Negate the X if its cost is less or equal than Y.
+    if (NegX && (CostX <= CostY)) {
+      Cost = CostX;
+      return DAG.getNode(Opcode, DL, VT, NegX, Y, Flags);
+    }
+
+    // Ignore X * 2.0 because that is expected to be canonicalized to X + X.
+    if (auto *C = isConstOrConstSplatFP(Op.getOperand(1)))
+      if (C->isExactlyValue(2.0) && Op.getOpcode() == ISD::FMUL)
+        break;
+
+    // Negate the Y if it is not expensive.
+    if (NegY) {
+      Cost = CostY;
+      return DAG.getNode(Opcode, DL, VT, X, NegY, Flags);
+    }
+    break;
   }
   case ISD::FMA:
   case ISD::FMAD: {
-    assert((DAG.getTarget().Options.NoSignedZerosFPMath ||
-            Flags.hasNoSignedZeros()) &&
-           "Expected NSZ fp-flag");
+    if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
+      break;
 
     SDValue X = Op.getOperand(0), Y = Op.getOperand(1), Z = Op.getOperand(2);
-    SDValue NegZ = negateExpression(Z, DAG, LegalOps, OptForSize, Depth);
-    NegatibleCost CostX = getNegatibleCost(X, DAG, LegalOps, OptForSize, Depth);
-    NegatibleCost CostY = getNegatibleCost(Y, DAG, LegalOps, OptForSize, Depth);
-    if (CostX <= CostY) {
-      // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
-      SDValue NegX = negateExpression(X, DAG, LegalOps, OptForSize, Depth);
+    NegatibleCost CostZ = NegatibleCost::Expensive;
+    SDValue NegZ =
+        getNegatedExpression(Z, DAG, LegalOps, OptForSize, CostZ, Depth);
+    // Give up if fail to negate the Z.
+    if (!NegZ)
+      break;
+
+    // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
+    NegatibleCost CostX = NegatibleCost::Expensive;
+    SDValue NegX =
+        getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
+    // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
+    NegatibleCost CostY = NegatibleCost::Expensive;
+    SDValue NegY =
+        getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
+
+    // Negate the X if its cost is less or equal than Y.
+    if (NegX && (CostX <= CostY)) {
+      Cost = std::min(CostX, CostZ);
       return DAG.getNode(Opcode, DL, VT, NegX, Y, NegZ, Flags);
     }
 
-    // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
-    SDValue NegY = negateExpression(Y, DAG, LegalOps, OptForSize, Depth);
-    return DAG.getNode(Opcode, DL, VT, X, NegY, NegZ, Flags);
+    // Negate the Y if it is not expensive.
+    if (NegY) {
+      Cost = std::min(CostY, CostZ);
+      return DAG.getNode(Opcode, DL, VT, X, NegY, NegZ, Flags);
+    }
+    break;
   }
 
   case ISD::FP_EXTEND:
   case ISD::FSIN:
-    return DAG.getNode(
-        Opcode, DL, VT,
-        negateExpression(Op.getOperand(0), DAG, LegalOps, OptForSize, Depth));
+    if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps,
+                                            OptForSize, Cost, Depth))
+      return DAG.getNode(Opcode, DL, VT, NegV);
+    break;
   case ISD::FP_ROUND:
-    return DAG.getNode(
-        ISD::FP_ROUND, DL, VT,
-        negateExpression(Op.getOperand(0), DAG, LegalOps, OptForSize, Depth),
-        Op.getOperand(1));
+    if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps,
+                                            OptForSize, Cost, Depth))
+      return DAG.getNode(ISD::FP_ROUND, DL, VT, NegV, Op.getOperand(1));
+    break;
   }
 
-  llvm_unreachable("Unknown code");
+  return SDValue();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index b453098b682b..1edc42adf9de 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -811,24 +811,24 @@ bool AMDGPUTargetLowering::isSDNodeAlwaysUniform(const SDNode * N) const {
   }
 }
 
-TargetLowering::NegatibleCost
-AMDGPUTargetLowering::getNegatibleCost(SDValue Op, SelectionDAG &DAG,
-                                       bool LegalOperations, bool ForCodeSize,
-                                       unsigned Depth) const {
+SDValue AMDGPUTargetLowering::getNegatedExpression(
+    SDValue Op, SelectionDAG &DAG, bool LegalOperations, bool ForCodeSize,
+    NegatibleCost &Cost, unsigned Depth) const {
+
   switch (Op.getOpcode()) {
   case ISD::FMA:
   case ISD::FMAD: {
     // Negating a fma is not free if it has users without source mods.
     if (!allUsesHaveSourceMods(Op.getNode()))
-      return NegatibleCost::Expensive;
+      return SDValue();
     break;
   }
   default:
     break;
   }
 
-  return TargetLowering::getNegatibleCost(Op, DAG, LegalOperations, ForCodeSize,
-                                          Depth);
+  return TargetLowering::getNegatedExpression(Op, DAG, LegalOperations,
+                                              ForCodeSize, Cost, Depth);
 }
 
 //===---------------------------------------------------------------------===//

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index 137a18968d5a..3934c7591a13 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -170,9 +170,10 @@ class AMDGPUTargetLowering : public TargetLowering {
   bool isZExtFree(EVT Src, EVT Dest) const override;
   bool isZExtFree(SDValue Val, EVT VT2) const override;
 
-  NegatibleCost getNegatibleCost(SDValue Op, SelectionDAG &DAG,
-                                 bool LegalOperations, bool ForCodeSize,
-                                 unsigned Depth) const override;
+  SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,
+                               bool LegalOperations, bool ForCodeSize,
+                               NegatibleCost &Cost,
+                               unsigned Depth) const override;
 
   bool isNarrowingProfitable(EVT VT1, EVT VT2) const override;
 

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 08f455b8bf2c..f4a88de7a1d7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44067,60 +44067,16 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
-TargetLowering::NegatibleCost
-X86TargetLowering::getNegatibleCost(SDValue Op, SelectionDAG &DAG,
-                                    bool LegalOperations, bool ForCodeSize,
-                                    unsigned Depth) const {
+SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
+                                                bool LegalOperations,
+                                                bool ForCodeSize,
+                                                NegatibleCost &Cost,
+                                                unsigned Depth) const {
   // fneg patterns are removable even if they have multiple uses.
-  if (isFNEG(DAG, Op.getNode(), Depth))
-    return NegatibleCost::Cheaper;
-
-  // Don't recurse exponentially.
-  if (Depth > SelectionDAG::MaxRecursionDepth)
-    return NegatibleCost::Expensive;
-
-  EVT VT = Op.getValueType();
-  EVT SVT = VT.getScalarType();
-  switch (Op.getOpcode()) {
-  case ISD::FMA:
-  case X86ISD::FMSUB:
-  case X86ISD::FNMADD:
-  case X86ISD::FNMSUB:
-  case X86ISD::FMADD_RND:
-  case X86ISD::FMSUB_RND:
-  case X86ISD::FNMADD_RND:
-  case X86ISD::FNMSUB_RND: {
-    if (!Op.hasOneUse() || !Subtarget.hasAnyFMA() || !isTypeLegal(VT) ||
-        !(SVT == MVT::f32 || SVT == MVT::f64) ||
-        !isOperationLegal(ISD::FMA, VT))
-      break;
-
-    // This is always negatible for free but we might be able to remove some
-    // extra operand negations as well.
-    for (int i = 0; i != 3; ++i) {
-      NegatibleCost V = getNegatibleCost(Op.getOperand(i), DAG, LegalOperations,
-                                         ForCodeSize, Depth + 1);
-      if (V == NegatibleCost::Cheaper)
-        return V;
-    }
-    return NegatibleCost::Neutral;
-  }
-  case X86ISD::FRCP:
-    return getNegatibleCost(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
-                            Depth + 1);
-  }
-
-  return TargetLowering::getNegatibleCost(Op, DAG, LegalOperations, ForCodeSize,
-                                          Depth);
-}
-
-SDValue X86TargetLowering::negateExpression(SDValue Op, SelectionDAG &DAG,
-                                            bool LegalOperations,
-                                            bool ForCodeSize,
-                                            unsigned Depth) const {
-  // fneg patterns are removable even if they have multiple uses.
-  if (SDValue Arg = isFNEG(DAG, Op.getNode(), Depth))
+  if (SDValue Arg = isFNEG(DAG, Op.getNode(), Depth)) {
+    Cost = NegatibleCost::Cheaper;
     return DAG.getBitcast(Op.getValueType(), Arg);
+  }
 
   EVT VT = Op.getValueType();
   EVT SVT = VT.getScalarType();
@@ -44151,6 +44107,9 @@ SDValue X86TargetLowering::negateExpression(SDValue Op, SelectionDAG &DAG,
     bool NegC = !!NewOps[2];
     unsigned NewOpc = negateFMAOpcode(Opc, NegA != NegB, NegC, true);
 
+    Cost = (NegA || NegB || NegC) ? NegatibleCost::Cheaper
+                                  : NegatibleCost::Neutral;
+
     // Fill in the non-negated ops with the original values.
     for (int i = 0, e = Op.getNumOperands(); i != e; ++i)
       if (!NewOps[i])
@@ -44158,14 +44117,15 @@ SDValue X86TargetLowering::negateExpression(SDValue Op, SelectionDAG &DAG,
     return DAG.getNode(NewOpc, SDLoc(Op), VT, NewOps);
   }
   case X86ISD::FRCP:
-    return DAG.getNode(Opc, SDLoc(Op), VT,
-                       negateExpression(Op.getOperand(0), DAG, LegalOperations,
-                                        ForCodeSize, Depth + 1));
+    if (SDValue NegOp0 =
+            getNegatedExpression(Op.getOperand(0), DAG, LegalOperations,
+                                 ForCodeSize, Cost, Depth + 1))
+      return DAG.getNode(Opc, SDLoc(Op), VT, NegOp0);
     break;
   }
 
-  return TargetLowering::negateExpression(Op, DAG, LegalOperations, ForCodeSize,
-                                          Depth);
+  return TargetLowering::getNegatedExpression(Op, DAG, LegalOperations,
+                                              ForCodeSize, Cost, Depth);
 }
 
 static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG,

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 06237208e818..1c9042c2ac38 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -935,17 +935,13 @@ namespace llvm {
     /// and some i16 instructions are slow.
     bool IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const override;
 
-    /// Returns whether computing the negated form of the specified expression
-    /// is more expensive, the same cost or cheaper.
-    NegatibleCost getNegatibleCost(SDValue Op, SelectionDAG &DAG,
-                                   bool LegalOperations, bool ForCodeSize,
-                                   unsigned Depth) const override;
-
-    /// If getNegatibleCost returns Neutral/Cheaper, return the newly negated
-    /// expression.
-    SDValue negateExpression(SDValue Op, SelectionDAG &DAG,
-                             bool LegalOperations, bool ForCodeSize,
-                             unsigned Depth) const override;
+    /// Return the newly negated expression if the cost is not expensive and
+    /// set the cost in \p Cost to indicate that if it is cheaper or neutral to
+    /// do the negation.
+    SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,
+                                 bool LegalOperations, bool ForCodeSize,
+                                 NegatibleCost &Cost,
+                                 unsigned Depth) const override;
 
     MachineBasicBlock *
     EmitInstrWithCustomInserter(MachineInstr &MI,

diff  --git a/llvm/test/CodeGen/X86/neg_fp.ll b/llvm/test/CodeGen/X86/neg_fp.ll
index 3c04aafcea4d..e07157c952f7 100644
--- a/llvm/test/CodeGen/X86/neg_fp.ll
+++ b/llvm/test/CodeGen/X86/neg_fp.ll
@@ -54,3 +54,30 @@ define double @negation_propagation(double* %arg, double %arg1, double %arg2) no
   %t18 = fadd double %t16, %t7
   ret double %t18
 }
+
+; This would crash because the negated expression for %sub4
+; creates a new use of %sub1 and that alters the negated cost
+
+define float @fdiv_extra_use_changes_cost(float %a0, float %a1, float %a2) nounwind {
+; CHECK-LABEL: fdiv_extra_use_changes_cost:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushl %eax
+; CHECK-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    movss {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; CHECK-NEXT:    subss {{[0-9]+}}(%esp), %xmm1
+; CHECK-NEXT:    movaps %xmm1, %xmm2
+; CHECK-NEXT:    mulss %xmm0, %xmm2
+; CHECK-NEXT:    subss %xmm1, %xmm0
+; CHECK-NEXT:    divss %xmm2, %xmm0
+; CHECK-NEXT:    movss %xmm0, (%esp)
+; CHECK-NEXT:    flds (%esp)
+; CHECK-NEXT:    popl %eax
+; CHECK-NEXT:    retl
+  %sub1 = fsub fast float %a0, %a1
+  %mul2 = fmul fast float %sub1, %a2
+  %neg = fneg fast float %a0
+  %add3 = fadd fast float %a1, %neg
+  %sub4 = fadd fast float %add3, %a2
+  %div5 = fdiv fast float %sub4, %mul2
+  ret float %div5
+}


        


More information about the llvm-commits mailing list