[llvm] r323981 - [DAGCombiner] filter out denorm inputs when calculating sqrt estimate (PR34994)

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 1 08:57:18 PST 2018


Author: spatel
Date: Thu Feb  1 08:57:18 2018
New Revision: 323981

URL: http://llvm.org/viewvc/llvm-project?rev=323981&view=rev
Log:
[DAGCombiner] filter out denorm inputs when calculating sqrt estimate (PR34994)

As shown in the example in PR34994:
https://bugs.llvm.org/show_bug.cgi?id=34994
...we can return a badly wrong answer (inf instead of 0.0) for the square root of a
denormal input when it is computed using a reciprocal square root estimate instruction.

Here, I've made the filtering of denormal inputs conditional on the function 
having "denormal-fp-math"="ieee" in its attributes. The other possible values for 
this attribute are 'preserve-sign' and 'positive-zero'. Without the "ieee" setting, 
only an exact 0.0 input is special-cased, as before.

So we don't generate this extra code by default with just '-ffast-math' (because 
then the function has no "denormal-fp-math" attribute at all), but the denormal 
check is generated if you pass '-ffast-math -fdenormal-fp-math=ieee' to clang.

As noted in the review, there may be other problems in clang that affect the 
results depending on the platform (at least Linux x86), but this change should 
make it possible to generate the desired code.

Differential Revision: https://reviews.llvm.org/D42323


Modified:
    llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/trunk/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=323981&r1=323980&r2=323981&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Thu Feb  1 08:57:18 2018
@@ -17454,19 +17454,34 @@ SDValue DAGCombiner::buildSqrtEstimateIm
             : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
 
       if (!Reciprocal) {
-        // Unfortunately, Est is now NaN if the input was exactly 0.0.
-        // Select out this case and force the answer to 0.0.
+        // The estimate is now completely wrong if the input was exactly 0.0 or
+        // possibly a denormal. Force the answer to 0.0 for those cases.
         EVT VT = Op.getValueType();
         SDLoc DL(Op);
-
-        SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
         EVT CCVT = getSetCCResultType(VT);
-        SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
-        AddToWorklist(ZeroCmp.getNode());
-
-        Est = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
-                          ZeroCmp, FPZero, Est);
-        AddToWorklist(Est.getNode());
+        ISD::NodeType SelOpcode = VT.isVector() ? ISD::VSELECT : ISD::SELECT;
+        const Function &F = DAG.getMachineFunction().getFunction();
+        Attribute Denorms = F.getFnAttribute("denormal-fp-math");
+        if (Denorms.getValueAsString().equals("ieee")) {
+          // fabs(X) < SmallestNormal ? 0.0 : Est
+          const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
+          APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
+          SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
+          SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+          SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
+          SDValue IsDenorm = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
+          Est = DAG.getNode(SelOpcode, DL, VT, IsDenorm, FPZero, Est);
+          AddToWorklist(Fabs.getNode());
+          AddToWorklist(IsDenorm.getNode());
+          AddToWorklist(Est.getNode());
+        } else {
+          // X == 0.0 ? 0.0 : Est
+          SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+          SDValue IsZero = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+          Est = DAG.getNode(SelOpcode, DL, VT, IsZero, FPZero, Est);
+          AddToWorklist(IsZero.getNode());
+          AddToWorklist(Est.getNode());
+        }
       }
     }
     return Est;

Modified: llvm/trunk/lib/Target/AArch64/AArch64ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/AArch64/AArch64ISelLowering.cpp?rev=323981&r1=323980&r2=323981&view=diff
==============================================================================
--- llvm/trunk/lib/Target/AArch64/AArch64ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/AArch64/AArch64ISelLowering.cpp Thu Feb  1 08:57:18 2018
@@ -5007,7 +5007,9 @@ SDValue AArch64TargetLowering::getSqrtEs
         Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags);
         Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
       }
-
+      // FIXME: This does not detect denorm inputs, so we might produce INF
+      // when we should produce 0.0. Try to refactor the code in DAGCombiner,
+      // so we don't have to duplicate it here.
       if (!Reciprocal) {
         EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
                                       VT);

Modified: llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll?rev=323981&r1=323980&r2=323981&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll (original)
+++ llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll Thu Feb  1 08:57:18 2018
@@ -121,8 +121,8 @@ define float @sqrtf_check_denorms(float
 ; SSE-NEXT:    mulss %xmm1, %xmm2
 ; SSE-NEXT:    addss {{.*}}(%rip), %xmm2
 ; SSE-NEXT:    mulss %xmm3, %xmm2
-; SSE-NEXT:    xorps %xmm1, %xmm1
-; SSE-NEXT:    cmpeqss %xmm1, %xmm0
+; SSE-NEXT:    andps {{.*}}(%rip), %xmm0
+; SSE-NEXT:    cmpltss {{.*}}(%rip), %xmm0
 ; SSE-NEXT:    andnps %xmm2, %xmm0
 ; SSE-NEXT:    retq
 ;
@@ -134,8 +134,8 @@ define float @sqrtf_check_denorms(float
 ; AVX-NEXT:    vaddss {{.*}}(%rip), %xmm1, %xmm1
 ; AVX-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
 ; AVX-NEXT:    vmulss %xmm1, %xmm2, %xmm1
-; AVX-NEXT:    vxorps %xmm2, %xmm2, %xmm2
-; AVX-NEXT:    vcmpeqss %xmm2, %xmm0, %xmm0
+; AVX-NEXT:    vandps {{.*}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    vcmpltss {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    vandnps %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %call = tail call float @__sqrtf_finite(float %x) #2
@@ -145,17 +145,19 @@ define float @sqrtf_check_denorms(float
 define <4 x float> @sqrt_v4f32_check_denorms(<4 x float> %x) #3 {
 ; SSE-LABEL: sqrt_v4f32_check_denorms:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    rsqrtps %xmm0, %xmm1
-; SSE-NEXT:    movaps %xmm0, %xmm2
-; SSE-NEXT:    mulps %xmm1, %xmm2
+; SSE-NEXT:    rsqrtps %xmm0, %xmm2
+; SSE-NEXT:    movaps %xmm0, %xmm1
+; SSE-NEXT:    mulps %xmm2, %xmm1
 ; SSE-NEXT:    movaps {{.*#+}} xmm3 = [-5.000000e-01,-5.000000e-01,-5.000000e-01,-5.000000e-01]
-; SSE-NEXT:    mulps %xmm2, %xmm3
-; SSE-NEXT:    mulps %xmm1, %xmm2
-; SSE-NEXT:    addps {{.*}}(%rip), %xmm2
-; SSE-NEXT:    mulps %xmm3, %xmm2
-; SSE-NEXT:    xorps %xmm1, %xmm1
-; SSE-NEXT:    cmpneqps %xmm1, %xmm0
-; SSE-NEXT:    andps %xmm2, %xmm0
+; SSE-NEXT:    mulps %xmm1, %xmm3
+; SSE-NEXT:    mulps %xmm2, %xmm1
+; SSE-NEXT:    addps {{.*}}(%rip), %xmm1
+; SSE-NEXT:    mulps %xmm3, %xmm1
+; SSE-NEXT:    andps {{.*}}(%rip), %xmm0
+; SSE-NEXT:    movaps {{.*#+}} xmm2 = [1.175494e-38,1.175494e-38,1.175494e-38,1.175494e-38]
+; SSE-NEXT:    cmpleps %xmm0, %xmm2
+; SSE-NEXT:    andps %xmm2, %xmm1
+; SSE-NEXT:    movaps %xmm1, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: sqrt_v4f32_check_denorms:
@@ -166,8 +168,9 @@ define <4 x float> @sqrt_v4f32_check_den
 ; AVX-NEXT:    vmulps %xmm1, %xmm2, %xmm1
 ; AVX-NEXT:    vaddps {{.*}}(%rip), %xmm1, %xmm1
 ; AVX-NEXT:    vmulps %xmm1, %xmm3, %xmm1
-; AVX-NEXT:    vxorps %xmm2, %xmm2, %xmm2
-; AVX-NEXT:    vcmpneqps %xmm2, %xmm0, %xmm0
+; AVX-NEXT:    vandps {{.*}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    vmovaps {{.*#+}} xmm2 = [1.175494e-38,1.175494e-38,1.175494e-38,1.175494e-38]
+; AVX-NEXT:    vcmpleps %xmm0, %xmm2, %xmm0
 ; AVX-NEXT:    vandps %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %call = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %x) #2




More information about the llvm-commits mailing list