[llvm] ffc3e80 - [NFC] [DAGCombine] Correct the result for sqrt even the iteration is zero

QingShan Zhang via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 24 20:02:53 PST 2021


Author: QingShan Zhang
Date: 2021-01-25T04:02:44Z
New Revision: ffc3e800c65ee58166255ff897f8b7e6d850ddda

URL: https://github.com/llvm/llvm-project/commit/ffc3e800c65ee58166255ff897f8b7e6d850ddda
DIFF: https://github.com/llvm/llvm-project/commit/ffc3e800c65ee58166255ff897f8b7e6d850ddda.diff

LOG: [NFC] [DAGCombine] Correct the result for sqrt even the iteration is zero

Currently, we correct the sqrt result only if the iteration count is > 0. This
doesn't make sense, as the two are not strictly related.

Reviewed By: dmgreen, spatel, RKSimon

Differential Revision: https://reviews.llvm.org/D94480

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/PowerPC/PPCISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 5a237074a5a3..1bc5377e6863 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4287,9 +4287,7 @@ class TargetLowering : public TargetLoweringBase {
   /// comparison may check if the operand is NAN, INF, zero, normal, etc. The
   /// result should be used as the condition operand for a select or branch.
   virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
-                                   const DenormalMode &Mode) const {
-    return SDValue();
-  }
+                                   const DenormalMode &Mode) const;
 
   /// Return a target-dependent result if the input operand is not suitable for
   /// use with a square root estimate calculation.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2ebf7c6ba0f3..cb273a6f299c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -22275,43 +22275,21 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
                           Reciprocal)) {
     AddToWorklist(Est.getNode());
 
-    if (Iterations) {
+    if (Iterations)
       Est = UseOneConstNR
             ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
             : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
-
-      if (!Reciprocal) {
-        SDLoc DL(Op);
-        EVT CCVT = getSetCCResultType(VT);
-        SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
-        DenormalMode DenormMode = DAG.getDenormalMode(VT);
-        // Try the target specific test first.
-        SDValue Test = TLI.getSqrtInputTest(Op, DAG, DenormMode);
-        if (!Test) {
-          // If no test provided by target, testing it with denormal inputs to
-          // avoid wrong estimate.
-          if (DenormMode.Input == DenormalMode::IEEE) {
-            // This is specifically a check for the handling of denormal inputs,
-            // not the result.
-
-            // Test = fabs(X) < SmallestNormal
-            const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
-            APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
-            SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
-            SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
-            Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
-          } else
-            // Test = X == 0.0
-            Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
-        }
-
-        // The estimate is now completely wrong if the input was exactly 0.0 or
-        // possibly a denormal. Force the answer to 0.0 or value provided by
-        // target for those cases.
-        Est = DAG.getNode(
-            Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
-            Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
-      }
+    if (!Reciprocal) {
+      SDLoc DL(Op);
+      // Try the target specific test first.
+      SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
+
+      // The estimate is now completely wrong if the input was exactly 0.0 or
+      // possibly a denormal. Force the answer to 0.0 or value provided by
+      // target for those cases.
+      Est = DAG.getNode(
+          Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
+          Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
     }
     return Est;
   }

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 80b745e0354a..7858bc6c43e4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5841,6 +5841,28 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
   return false;
 }
 
+SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
+                                         const DenormalMode &Mode) const {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+  // Testing it with denormal inputs to avoid wrong estimate.
+  if (Mode.Input == DenormalMode::IEEE) {
+    // This is specifically a check for the handling of denormal inputs,
+    // not the result.
+
+    // Test = fabs(X) < SmallestNormal
+    const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
+    APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
+    SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
+    SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
+    return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
+  }
+  // Test = X == 0.0
+  return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+}
+
 SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
                                              bool LegalOps, bool OptForSize,
                                              NegatibleCost &Cost,

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f485490199fd..1be09186dc0a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7471,6 +7471,22 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
   return SDValue();
 }
 
+SDValue
+AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
+                                        const DenormalMode &Mode) const {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+  return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+}
+
+SDValue
+AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  return Op;
+}
+
 SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
                                                SelectionDAG &DAG, int Enabled,
                                                int &ExtraSteps,
@@ -7494,17 +7510,8 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
         Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags);
         Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
       }
-      if (!Reciprocal) {
-        EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
-                                      VT);
-        SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
-        SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ);
-
+      if (!Reciprocal)
         Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags);
-        // Correct the result if the operand is 0.0.
-        Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL,
-                               VT, Eq, Operand, Estimate);
-      }
 
       ExtraSteps = 0;
       return Estimate;

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 71e59f8f6ed7..9550197159e6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -961,6 +961,10 @@ class AArch64TargetLowering : public TargetLowering {
                           bool Reciprocal) const override;
   SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                            int &ExtraSteps) const override;
+  SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
+                           const DenormalMode &Mode) const override;
+  SDValue getSqrtResultForDenormInput(SDValue Operand,
+                                      SelectionDAG &DAG) const override;
   unsigned combineRepeatedFPDivisors() const override;
 
   ConstraintType getConstraintType(StringRef Constraint) const override;

diff  --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
index a6382bd17a9f..9215c17cb94b 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
@@ -12133,7 +12133,7 @@ SDValue PPCTargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
   if (!isTypeLegal(MVT::i1) ||
       (VT != MVT::f64 &&
        ((VT != MVT::v2f64 && VT != MVT::v4f32) || !Subtarget.hasVSX())))
-    return SDValue();
+    return TargetLowering::getSqrtInputTest(Op, DAG, Mode);
 
   SDLoc DL(Op);
   // The output register of FTSQRT is CR field.


        


More information about the llvm-commits mailing list