[llvm] 65420c8 - DAG: Use getNegatedExpression in combineMinNumMaxNum

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 23 02:07:31 PST 2023


Author: Matt Arsenault
Date: 2023-01-23T06:07:23-04:00
New Revision: 65420c8041f4ca44a3a14c5f7faf426ee6a7c6a4

URL: https://github.com/llvm/llvm-project/commit/65420c8041f4ca44a3a14c5f7faf426ee6a7c6a4
DIFF: https://github.com/llvm/llvm-project/commit/65420c8041f4ca44a3a14c5f7faf426ee6a7c6a4.diff

LOG: DAG: Use getNegatedExpression in combineMinNumMaxNum

Computing the negated RHS expression just to see if it compares equal
and then throwing it away feels dirty.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/ARM/unsafe-fneg-select-minnum-maxnum-combine.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 21121d71a5fdd..8317a2e146a08 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4070,22 +4070,34 @@ class TargetLowering : public TargetLoweringBase {
                                        NegatibleCost &Cost,
                                        unsigned Depth = 0) const;
 
-  /// This is the helper function to return the newly negated expression only
-  /// when the cost is cheaper.
-  SDValue getCheaperNegatedExpression(SDValue Op, SelectionDAG &DAG,
-                                      bool LegalOps, bool OptForSize,
-                                      unsigned Depth = 0) const {
+  SDValue getCheaperOrNeutralNegatedExpression(
+      SDValue Op, SelectionDAG &DAG, bool LegalOps, bool OptForSize,
+      const NegatibleCost CostThreshold = NegatibleCost::Neutral,
+      unsigned Depth = 0) const {
     NegatibleCost Cost = NegatibleCost::Expensive;
     SDValue Neg =
         getNegatedExpression(Op, DAG, LegalOps, OptForSize, Cost, Depth);
-    if (Neg && Cost == NegatibleCost::Cheaper)
+    if (!Neg)
+      return SDValue();
+
+    if (Cost <= CostThreshold)
       return Neg;
+
     // Remove the new created node to avoid the side effect to the DAG.
-    if (Neg && Neg->use_empty())
+    if (Neg->use_empty())
       DAG.RemoveDeadNode(Neg.getNode());
     return SDValue();
   }
 
+  /// This is the helper function to return the newly negated expression only
+  /// when the cost is cheaper.
+  SDValue getCheaperNegatedExpression(SDValue Op, SelectionDAG &DAG,
+                                      bool LegalOps, bool OptForSize,
+                                      unsigned Depth = 0) const {
+    return getCheaperOrNeutralNegatedExpression(Op, DAG, LegalOps, OptForSize,
+                                                NegatibleCost::Cheaper, Depth);
+  }
+
   /// This is the helper function to return the newly negated expression if
   /// the cost is not expensive.
   SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps,

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8c87fc4acd3a6..73172bb5c1de1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -387,6 +387,10 @@ namespace {
     SDValue PromoteExtend(SDValue Op);
     bool PromoteLoad(SDValue Op);
 
+    SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
+                                SDValue RHS, SDValue True, SDValue False,
+                                ISD::CondCode CC);
+
     /// Call the node-specific routine that knows how to fold each
     /// particular type of node. If that doesn't do anything, try the
     /// target-specific DAG combines.
@@ -10392,21 +10396,20 @@ static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
 }
 
 /// Generate Min/Max node
-static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
-                                   SDValue RHS, SDValue True, SDValue False,
-                                   ISD::CondCode CC, const TargetLowering &TLI,
-                                   SelectionDAG &DAG) {
+SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
+                                         SDValue RHS, SDValue True,
+                                         SDValue False, ISD::CondCode CC) {
   if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
     return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
 
   // If we can't directly match this, try to see if we can pull an fneg out of
   // the select.
-  if (True.getOpcode() != ISD::FNEG)
+  SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
+      True, DAG, LegalOperations, ForCodeSize);
+  if (!NegTrue)
     return SDValue();
 
-  ConstantFPSDNode *CRHS = dyn_cast<ConstantFPSDNode>(RHS);
-  ConstantFPSDNode *CFalse = dyn_cast<ConstantFPSDNode>(False);
-  SDValue NegTrue = True.getOperand(0);
+  HandleSDNode NegTrueHandle(NegTrue);
 
   // Try to unfold an fneg from the select if we are comparing the negated
   // constant.
@@ -10414,14 +10417,18 @@ static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
   // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
   //
   // TODO: Handle fabs
-  if (LHS == NegTrue && CFalse && CRHS) {
-    APFloat NegRHS = neg(CRHS->getValueAPF());
-    if (NegRHS == CFalse->getValueAPF()) {
-      SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
-                                                 False, CC, TLI, DAG);
-      if (Combined)
+  if (LHS == NegTrue) {
+    // Negate the compare RHS and check whether the result matches the
+    // select's false operand.
+    SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
+        RHS, DAG, LegalOperations, ForCodeSize);
+    if (NegRHS) {
+      HandleSDNode NegRHSHandle(NegRHS);
+      if (NegRHS == False) {
+        SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
+                                                   False, CC, TLI, DAG);
         return DAG.getNode(ISD::FNEG, DL, VT, Combined);
-      return SDValue();
+      }
     }
   }
 
@@ -10812,8 +10819,8 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
     //
     // This is OK if we don't care what happens if either operand is a NaN.
     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
-      if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2,
-                                                CC, TLI, DAG))
+      if (SDValue FMinMax =
+              combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, CC))
         return FMinMax;
 
     // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
@@ -11325,8 +11332,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
     // NaN.
     //
     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
-      if (SDValue FMinMax =
-              combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC, TLI, DAG))
+      if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC))
         return FMinMax;
     }
 

diff  --git a/llvm/test/CodeGen/ARM/unsafe-fneg-select-minnum-maxnum-combine.ll b/llvm/test/CodeGen/ARM/unsafe-fneg-select-minnum-maxnum-combine.ll
index 23fdf07084705..664272ef8c098 100644
--- a/llvm/test/CodeGen/ARM/unsafe-fneg-select-minnum-maxnum-combine.ll
+++ b/llvm/test/CodeGen/ARM/unsafe-fneg-select-minnum-maxnum-combine.ll
@@ -67,13 +67,10 @@ define float @select_fsub0_or_8_cmp_olt_fsub1_neg8_f32(float %a, float %b) #0 {
 ; CHECK-NEXT:    vmov.f32 s0, #4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
 ; CHECK-NEXT:    vmov.f32 s4, #-8.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #8.000000e+00
-; CHECK-NEXT:    vsub.f32 s6, s0, s2
-; CHECK-NEXT:    vsub.f32 s0, s2, s0
-; CHECK-NEXT:    vcmp.f32 s4, s6
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s0, s8
+; CHECK-NEXT:    vsub.f32 s0, s0, s2
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %sub.0 = fsub nnan nsz float 4.0, %a
   %sub.1 = fsub nnan nsz float %a, 4.0
@@ -88,13 +85,10 @@ define float @select_fsub0_or_neg8_cmp_olt_fsub1_8_f32(float %a, float %b) #0 {
 ; CHECK-NEXT:    vmov.f32 s0, #4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
 ; CHECK-NEXT:    vmov.f32 s4, #8.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #-8.000000e+00
-; CHECK-NEXT:    vsub.f32 s6, s0, s2
-; CHECK-NEXT:    vsub.f32 s0, s2, s0
-; CHECK-NEXT:    vcmp.f32 s4, s6
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s0, s8
+; CHECK-NEXT:    vsub.f32 s0, s0, s2
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %sub.0 = fsub nnan nsz float 4.0, %a
   %sub.1 = fsub nnan nsz float %a, 4.0
@@ -108,15 +102,11 @@ define float @select_mul4_or_neg8_cmp_olt_mulneg4_8_f32(float %a, float %b) #0 {
 ; CHECK:       @ %bb.0:
 ; CHECK-NEXT:    vmov.f32 s0, #-4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
-; CHECK-NEXT:    vmov.f32 s6, #8.000000e+00
-; CHECK-NEXT:    vmov.f32 s4, #4.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #-8.000000e+00
+; CHECK-NEXT:    vmov.f32 s4, #8.000000e+00
 ; CHECK-NEXT:    vmul.f32 s0, s2, s0
-; CHECK-NEXT:    vmul.f32 s2, s2, s4
-; CHECK-NEXT:    vcmp.f32 s6, s0
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s2, s8
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %mul.0 = fmul nnan nsz float %a, 4.0
   %mul.1 = fmul nnan nsz float %a, -4.0
@@ -130,15 +120,11 @@ define float @select_mul4_or_8_cmp_olt_mulneg4_neg8_f32(float %a, float %b) #0 {
 ; CHECK:       @ %bb.0:
 ; CHECK-NEXT:    vmov.f32 s0, #-4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
-; CHECK-NEXT:    vmov.f32 s6, #-8.000000e+00
-; CHECK-NEXT:    vmov.f32 s4, #4.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #8.000000e+00
+; CHECK-NEXT:    vmov.f32 s4, #-8.000000e+00
 ; CHECK-NEXT:    vmul.f32 s0, s2, s0
-; CHECK-NEXT:    vmul.f32 s2, s2, s4
-; CHECK-NEXT:    vcmp.f32 s6, s0
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s2, s8
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %mul.0 = fmul nnan nsz float %a, 4.0
   %mul.1 = fmul nnan nsz float %a, -4.0
@@ -194,15 +180,11 @@ define float @select_mulneg4_or_neg8_cmp_olt_mul4_8_f32(float %a, float %b) #0 {
 ; CHECK:       @ %bb.0:
 ; CHECK-NEXT:    vmov.f32 s0, #4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
-; CHECK-NEXT:    vmov.f32 s6, #8.000000e+00
-; CHECK-NEXT:    vmov.f32 s4, #-4.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #-8.000000e+00
+; CHECK-NEXT:    vmov.f32 s4, #8.000000e+00
 ; CHECK-NEXT:    vmul.f32 s0, s2, s0
-; CHECK-NEXT:    vmul.f32 s2, s2, s4
-; CHECK-NEXT:    vcmp.f32 s6, s0
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s2, s8
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %mul.0 = fmul nnan nsz float %a, -4.0
   %mul.1 = fmul nnan nsz float %a, 4.0


        


More information about the llvm-commits mailing list