[llvm] d1f0bdf - [SDAG] remove use restriction in isNegatibleForFree() when called from getNegatedExpression()

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 11 10:30:47 PST 2019


Author: Sanjay Patel
Date: 2019-12-11T13:30:39-05:00
New Revision: d1f0bdf2d2df9bdf11ee2ddfff3df50e53f2f042

URL: https://github.com/llvm/llvm-project/commit/d1f0bdf2d2df9bdf11ee2ddfff3df50e53f2f042
DIFF: https://github.com/llvm/llvm-project/commit/d1f0bdf2d2df9bdf11ee2ddfff3df50e53f2f042.diff

LOG: [SDAG] remove use restriction in isNegatibleForFree() when called from getNegatedExpression()

This is an alternate fix for the bug discussed in D70595.
This also includes minimal tests for other in-tree targets
to show the problem more generally.

We check the number of uses as a predicate for whether some
value is free to negate, but that use count can change as we
rewrite the expression in getNegatedExpression(). So something
that was marked free to negate during the cost evaluation
phase becomes not free to negate during the rewrite phase (or
the inverse - something that was not free becomes free).
This can lead to a crash/assert because we expect
everything in an expression that is negatible to be handled
in the corresponding code within getNegatedExpression().

This patch skips the use check during the rewrite phase.
So we determine that some expression isNegatibleForFree
(identically to without this patch), but during the rewrite,
don't rely on use counts to decide how to create the optimal
expression.

Differential Revision: https://reviews.llvm.org/D70975

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/lib/Target/X86/X86ISelLowering.h
    llvm/test/CodeGen/AArch64/arm64-fmadd.ll
    llvm/test/CodeGen/X86/fma-fneg-combine-2.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 0726bdfec20e..687a2eb9296f 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3442,8 +3442,16 @@ class TargetLowering : public TargetLoweringBase {
   /// Return 1 if we can compute the negated form of the specified expression
   /// for the same cost as the expression itself, or 2 if we can compute the
   /// negated form more cheaply than the expression itself. Else return 0.
+  ///
+  /// EnableUseCheck specifies whether the number of uses of a value affects
+  /// if negation is considered free. This is needed because the number of uses
+  /// of any value may change as we rewrite the expression. Therefore, when
+  /// called from getNegatedExpression(), we must explicitly set EnableUseCheck
+  /// to false to avoid getting a different answer than when called from other
+  /// contexts.
   virtual char isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
                                   bool LegalOperations, bool ForCodeSize,
+                                  bool EnableUseCheck = true,
                                   unsigned Depth = 0) const;
 
   /// If isNegatibleForFree returns true, return the newly negated expression.

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index f8afdaf086ab..05011aebb9d3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5413,18 +5413,21 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
 
 char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
                                         bool LegalOperations, bool ForCodeSize,
+                                        bool EnableUseCheck,
                                         unsigned Depth) const {
   // fneg is removable even if it has multiple uses.
   if (Op.getOpcode() == ISD::FNEG)
     return 2;
 
-  // Don't allow anything with multiple uses unless we know it is free.
+  // If the caller requires checking uses, don't allow anything with multiple
+  // uses unless we know it is free.
   EVT VT = Op.getValueType();
   const SDNodeFlags Flags = Op->getFlags();
   const TargetOptions &Options = DAG.getTarget().Options;
-  if (!Op.hasOneUse() && !(Op.getOpcode() == ISD::FP_EXTEND &&
-                           isFPExtFree(VT, Op.getOperand(0).getValueType())))
-    return 0;
+  if (EnableUseCheck)
+    if (!Op.hasOneUse() && !(Op.getOpcode() == ISD::FP_EXTEND &&
+                             isFPExtFree(VT, Op.getOperand(0).getValueType())))
+      return 0;
 
   // Don't recurse exponentially.
   if (Depth > SelectionDAG::MaxRecursionDepth)
@@ -5468,11 +5471,11 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
 
     // fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
     if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
-                                    ForCodeSize, Depth + 1))
+                                    ForCodeSize, EnableUseCheck, Depth + 1))
       return V;
     // fold (fneg (fadd A, B)) -> (fsub (fneg B), A)
     return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
-                              ForCodeSize, Depth + 1);
+                              ForCodeSize, EnableUseCheck, Depth + 1);
   case ISD::FSUB:
     // We can't turn -(A-B) into B-A when we honor signed zeros.
     if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
@@ -5485,7 +5488,7 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
   case ISD::FDIV:
     // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y))
     if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
-                                    ForCodeSize, Depth + 1))
+                                    ForCodeSize, EnableUseCheck, Depth + 1))
       return V;
 
     // Ignore X * 2.0 because that is expected to be canonicalized to X + X.
@@ -5494,7 +5497,7 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
         return 0;
 
     return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
-                              ForCodeSize, Depth + 1);
+                              ForCodeSize, EnableUseCheck, Depth + 1);
 
   case ISD::FMA:
   case ISD::FMAD: {
@@ -5504,15 +5507,15 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
     // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
     // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
     char V2 = isNegatibleForFree(Op.getOperand(2), DAG, LegalOperations,
-                                 ForCodeSize, Depth + 1);
+                                 ForCodeSize, EnableUseCheck, Depth + 1);
     if (!V2)
       return 0;
 
     // One of Op0/Op1 must be cheaply negatible, then select the cheapest.
     char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
-                                 ForCodeSize, Depth + 1);
+                                 ForCodeSize, EnableUseCheck, Depth + 1);
     char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
-                                 ForCodeSize, Depth + 1);
+                                 ForCodeSize, EnableUseCheck, Depth + 1);
     char V01 = std::max(V0, V1);
     return V01 ? std::max(V01, V2) : 0;
   }
@@ -5521,7 +5524,7 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
   case ISD::FP_ROUND:
   case ISD::FSIN:
     return isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
-                              ForCodeSize, Depth + 1);
+                              ForCodeSize, EnableUseCheck, Depth + 1);
   }
 
   return 0;
@@ -5565,7 +5568,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
 
     // fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
     if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
-                           Depth + 1))
+                           false, Depth + 1))
       return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(),
                          getNegatedExpression(Op.getOperand(0), DAG,
                                               LegalOperations, ForCodeSize,
@@ -5592,7 +5595,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
   case ISD::FDIV:
     // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
     if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
-                           Depth + 1))
+                           false, Depth + 1))
       return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
                          getNegatedExpression(Op.getOperand(0), DAG,
                                               LegalOperations, ForCodeSize,
@@ -5616,9 +5619,9 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
                                         ForCodeSize, Depth + 1);
 
     char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
-                                 ForCodeSize, Depth + 1);
+                                 ForCodeSize, false, Depth + 1);
     char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
-                                 ForCodeSize, Depth + 1);
+                                 ForCodeSize, false, Depth + 1);
     if (V0 >= V1) {
       // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
       SDValue Neg0 = getNegatedExpression(

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 866ee5b9a602..cdb588ddb8a2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -41898,6 +41898,7 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
 char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
                                            bool LegalOperations,
                                            bool ForCodeSize,
+                                           bool EnableUseCheck,
                                            unsigned Depth) const {
   // fneg patterns are removable even if they have multiple uses.
   if (isFNEG(DAG, Op.getNode(), Depth))
@@ -41926,7 +41927,7 @@ char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
     // extra operand negations as well.
     for (int i = 0; i != 3; ++i) {
       char V = isNegatibleForFree(Op.getOperand(i), DAG, LegalOperations,
-                                  ForCodeSize, Depth + 1);
+                                  ForCodeSize, EnableUseCheck, Depth + 1);
       if (V == 2)
         return V;
     }
@@ -41935,7 +41936,8 @@ char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
   }
 
   return TargetLowering::isNegatibleForFree(Op, DAG, LegalOperations,
-                                            ForCodeSize, Depth);
+                                            ForCodeSize, EnableUseCheck,
+                                            Depth);
 }
 
 SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
@@ -41967,7 +41969,7 @@ SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
     SmallVector<SDValue, 4> NewOps(Op.getNumOperands(), SDValue());
     for (int i = 0; i != 3; ++i) {
       char V = isNegatibleForFree(Op.getOperand(i), DAG, LegalOperations,
-                                  ForCodeSize, Depth + 1);
+                                  ForCodeSize, false, Depth + 1);
       if (V == 2)
         NewOps[i] = getNegatedExpression(Op.getOperand(i), DAG, LegalOperations,
                                          ForCodeSize, Depth + 1);

diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 016120064134..3bbf3b59ac5a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -809,7 +809,8 @@ namespace llvm {
     /// for the same cost as the expression itself, or 2 if we can compute the
     /// negated form more cheaply than the expression itself. Else return 0.
     char isNegatibleForFree(SDValue Op, SelectionDAG &DAG, bool LegalOperations,
-                            bool ForCodeSize, unsigned Depth) const override;
+                            bool ForCodeSize, bool EnableUseCheck,
+                            unsigned Depth) const override;
 
     /// If isNegatibleForFree returns true, return the newly negated expression.
     SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,

diff --git a/llvm/test/CodeGen/AArch64/arm64-fmadd.ll b/llvm/test/CodeGen/AArch64/arm64-fmadd.ll
index 203ce623647f..dffa83aa11b2 100644
--- a/llvm/test/CodeGen/AArch64/arm64-fmadd.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-fmadd.ll
@@ -88,5 +88,23 @@ entry:
   ret double %0
 }
 
+; This would crash while trying getNegatedExpression().
+
+define float @negated_constant(float %x) {
+; CHECK-LABEL: negated_constant:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-1037565952
+; CHECK-NEXT:    mov w9, #1109917696
+; CHECK-NEXT:    fmov s1, w8
+; CHECK-NEXT:    fmul s1, s0, s1
+; CHECK-NEXT:    fmov s2, w9
+; CHECK-NEXT:    fmadd s0, s0, s2, s1
+; CHECK-NEXT:    ret
+  %m = fmul float %x, 42.0
+  %fma = call nsz float @llvm.fma.f32(float %x, float -42.0, float %m)
+  %nfma = fneg float %fma
+  ret float %nfma
+}
+
 declare float @llvm.fma.f32(float, float, float) nounwind readnone
 declare double @llvm.fma.f64(double, double, double) nounwind readnone

diff --git a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
index f9e87955270b..9c846e3f555c 100644
--- a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
+++ b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
@@ -86,4 +86,24 @@ entry:
   ret float %1
 }
 
+; This would crash while trying getNegatedExpression().
+
+define float @negated_constant(float %x) {
+; FMA3-LABEL: negated_constant:
+; FMA3:       # %bb.0:
+; FMA3-NEXT:    vmulss {{.*}}(%rip), %xmm0, %xmm1
+; FMA3-NEXT:    vfmadd132ss {{.*#+}} xmm0 = (xmm0 * mem) + xmm1
+; FMA3-NEXT:    retq
+;
+; FMA4-LABEL: negated_constant:
+; FMA4:       # %bb.0:
+; FMA4-NEXT:    vmulss {{.*}}(%rip), %xmm0, %xmm1
+; FMA4-NEXT:    vfmaddss %xmm1, {{.*}}(%rip), %xmm0, %xmm0
+; FMA4-NEXT:    retq
+  %m = fmul float %x, 42.0
+  %fma = call nsz float @llvm.fma.f32(float %x, float -42.0, float %m)
+  %nfma = fneg float %fma
+  ret float %nfma
+}
+
 declare float @llvm.fma.f32(float, float, float)


        


More information about the llvm-commits mailing list