[llvm] aca8b9d - [DAG] SimplifyDemandedBits - if we're only demanding the signbits, a MIN/MAX node can be simplified to an OR or AND node

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 1 02:59:33 PDT 2023


Author: Simon Pilgrim
Date: 2023-09-01T10:56:32+01:00
New Revision: aca8b9d0d56eea757b6903f75ff529f545ab229d

URL: https://github.com/llvm/llvm-project/commit/aca8b9d0d56eea757b6903f75ff529f545ab229d
DIFF: https://github.com/llvm/llvm-project/commit/aca8b9d0d56eea757b6903f75ff529f545ab229d.diff

LOG: [DAG] SimplifyDemandedBits - if we're only demanding the signbits, a MIN/MAX node can be simplified to an OR or AND node

Extending the signbit-only case: if the sign bits extend down through all of the demanded bits, then SMIN/SMAX/UMIN/UMAX nodes can be simplified to OR/AND/AND/OR respectively.

Alive2: https://alive2.llvm.org/ce/z/mFVFAn (general case)
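
(Editor's note, not part of the commit: the fold fires when NumSignBits >= BitWidth - countr_zero(DemandedBits), e.g. for i32 with DemandedBits = 0xFF000000 the 8 demanded bits must all be sign bits. The following standalone C++ sketch, using plain i8 arithmetic rather than LLVM APIs and an illustrative choice of K = 3 known sign bits with only the top K bits demanded, brute-forces the equivalence the fold relies on: within the sign-bit region each operand is either all-ones or all-zeros, so SMIN/UMAX pick up a one bit whenever either operand has one (OR), and SMAX/UMIN keep a one bit only when both do (AND).)

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const unsigned K = 3;                                // assumed number of known sign bits
  const uint8_t TopMask = (uint8_t)(0xFFu << (8 - K)); // the demanded (upper) bits

  // An operand has at least K sign bits iff its top K bits are all equal
  // (all copies of the sign bit), i.e. all zeros or all ones.
  auto HasKSignBits = [&](uint8_t V) {
    uint8_t Top = V & TopMask;
    return Top == 0 || Top == TopMask;
  };

  for (int A = -128; A < 128; ++A) {
    for (int B = -128; B < 128; ++B) {
      int8_t SA = (int8_t)A, SB = (int8_t)B;
      uint8_t UA = (uint8_t)SA, UB = (uint8_t)SB;
      if (!HasKSignBits(UA) || !HasKSignBits(UB))
        continue;

      // Only the demanded (top K) bits of the result need to match.
      assert(((uint8_t)std::min(SA, SB) & TopMask) == ((UA | UB) & TopMask)); // smin -> or
      assert(((uint8_t)std::max(SA, SB) & TopMask) == ((UA & UB) & TopMask)); // smax -> and
      assert((std::min(UA, UB) & TopMask) == ((UA & UB) & TopMask));          // umin -> and
      assert((std::max(UA, UB) & TopMask) == ((UA | UB) & TopMask));          // umax -> or
    }
  }
  std::puts("min/max -> or/and signbit equivalences hold for all i8 pairs");
}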

Differential Revision: https://reviews.llvm.org/D158364

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/test/CodeGen/X86/known-signbits-vector.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 8d66c9f317e1ec..a8e15b2b1760f3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2156,54 +2156,49 @@ bool TargetLowering::SimplifyDemandedBits(
     }
     break;
   }
-  case ISD::SMIN: {
-    SDValue Op0 = Op.getOperand(0);
-    SDValue Op1 = Op.getOperand(1);
-    // If we're only wanting the signbit, then we can simplify to OR node.
-    // TODO: Extend this based on ComputeNumSignBits.
-    if (DemandedBits.isSignMask())
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
-    break;
-  }
-  case ISD::SMAX: {
-    SDValue Op0 = Op.getOperand(0);
-    SDValue Op1 = Op.getOperand(1);
-    // If we're only wanting the signbit, then we can simplify to AND node.
-    // TODO: Extend this based on ComputeNumSignBits.
-    if (DemandedBits.isSignMask())
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, Op1));
-    break;
-  }
-  case ISD::UMIN: {
-    SDValue Op0 = Op.getOperand(0);
-    SDValue Op1 = Op.getOperand(1);
-    // If we're only wanting the msb, then we can simplify to AND node.
-    if (DemandedBits.isSignMask())
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, Op1));
-    // Check if one arg is always less than (or equal) to the other arg.
-    KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
-    KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    Known = KnownBits::umin(Known0, Known1);
-    if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
-      return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
-    if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
-      return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
-    break;
-  }
+  case ISD::SMIN:
+  case ISD::SMAX:
+  case ISD::UMIN:
   case ISD::UMAX: {
+    unsigned Opc = Op.getOpcode();
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
-    // If we're only wanting the msb, then we can simplify to OR node.
-    if (DemandedBits.isSignMask())
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
-    // Check if one arg is always greater than (or equal) to the other arg.
+
+    // If we're only demanding signbits, then we can simplify to OR/AND node.
+    unsigned BitOp =
+        (Opc == ISD::SMIN || Opc == ISD::UMAX) ? ISD::OR : ISD::AND;
+    unsigned NumSignBits =
+        std::min(TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1),
+                 TLO.DAG.ComputeNumSignBits(Op1, DemandedElts, Depth + 1));
+    unsigned NumDemandedUpperBits = BitWidth - DemandedBits.countr_zero();
+    if (NumSignBits >= NumDemandedUpperBits)
+      return TLO.CombineTo(Op, TLO.DAG.getNode(BitOp, SDLoc(Op), VT, Op0, Op1));
+
+    // Check if one arg is always less/greater than (or equal) to the other arg.
     KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
     KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    Known = KnownBits::umax(Known0, Known1);
-    if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
-      return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
-    if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
-      return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
+    switch (Opc) {
+    case ISD::SMIN:
+      // TODO: Add KnownBits::sle/slt handling.
+      break;
+    case ISD::SMAX:
+      // TODO: Add KnownBits::sge/sgt handling.
+      break;
+    case ISD::UMIN:
+      if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
+        return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
+      if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
+        return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
+      Known = KnownBits::umin(Known0, Known1);
+      break;
+    case ISD::UMAX:
+      if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
+        return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
+      if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
+        return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
+      Known = KnownBits::umax(Known0, Known1);
+      break;
+    }
     break;
   }
   case ISD::BITREVERSE: {

diff  --git a/llvm/test/CodeGen/X86/known-signbits-vector.ll b/llvm/test/CodeGen/X86/known-signbits-vector.ll
index de7186584e67a3..e500801b69c4d0 100644
--- a/llvm/test/CodeGen/X86/known-signbits-vector.ll
+++ b/llvm/test/CodeGen/X86/known-signbits-vector.ll
@@ -483,28 +483,24 @@ define <4 x float> @signbits_ashr_sext_select_shuffle_sitofp(<4 x i64> %a0, <4 x
 define <4 x i32> @signbits_mask_ashr_smax(<4 x i32> %a0, <4 x i32> %a1) {
 ; X86-LABEL: signbits_mask_ashr_smax:
 ; X86:       # %bb.0:
-; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X86-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X86-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; X86-NEXT:    vpand %xmm1, %xmm0, %xmm0
 ; X86-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X86-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
 ; X86-NEXT:    retl
 ;
 ; X64-AVX1-LABEL: signbits_mask_ashr_smax:
 ; X64-AVX1:       # %bb.0:
-; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT:    vpand %xmm1, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX1-NEXT:    retq
 ;
 ; X64-AVX2-LABEL: signbits_mask_ashr_smax:
 ; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpand %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX2-NEXT:    retq
@@ -521,28 +517,24 @@ declare <4 x i32> @llvm.smax.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
 define <4 x i32> @signbits_mask_ashr_smin(<4 x i32> %a0, <4 x i32> %a1) {
 ; X86-LABEL: signbits_mask_ashr_smin:
 ; X86:       # %bb.0:
-; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X86-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X86-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; X86-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; X86-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X86-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
 ; X86-NEXT:    retl
 ;
 ; X64-AVX1-LABEL: signbits_mask_ashr_smin:
 ; X64-AVX1:       # %bb.0:
-; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX1-NEXT:    retq
 ;
 ; X64-AVX2-LABEL: signbits_mask_ashr_smin:
 ; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpor %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX2-NEXT:    retq
@@ -559,28 +551,24 @@ declare <4 x i32> @llvm.smin.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
 define <4 x i32> @signbits_mask_ashr_umax(<4 x i32> %a0, <4 x i32> %a1) {
 ; X86-LABEL: signbits_mask_ashr_umax:
 ; X86:       # %bb.0:
-; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X86-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X86-NEXT:    vpmaxud %xmm1, %xmm0, %xmm0
+; X86-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; X86-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X86-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
 ; X86-NEXT:    retl
 ;
 ; X64-AVX1-LABEL: signbits_mask_ashr_umax:
 ; X64-AVX1:       # %bb.0:
-; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT:    vpmaxud %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX1-NEXT:    retq
 ;
 ; X64-AVX2-LABEL: signbits_mask_ashr_umax:
 ; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT:    vpmaxud %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpor %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX2-NEXT:    retq
@@ -597,28 +585,24 @@ declare <4 x i32> @llvm.umax.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
 define <4 x i32> @signbits_mask_ashr_umin(<4 x i32> %a0, <4 x i32> %a1) {
 ; X86-LABEL: signbits_mask_ashr_umin:
 ; X86:       # %bb.0:
-; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X86-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X86-NEXT:    vpminud %xmm1, %xmm0, %xmm0
+; X86-NEXT:    vpand %xmm1, %xmm0, %xmm0
 ; X86-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X86-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
 ; X86-NEXT:    retl
 ;
 ; X64-AVX1-LABEL: signbits_mask_ashr_umin:
 ; X64-AVX1:       # %bb.0:
-; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT:    vpminud %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT:    vpand %xmm1, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX1-NEXT:    retq
 ;
 ; X64-AVX2-LABEL: signbits_mask_ashr_umin:
 ; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT:    vpminud %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpand %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX2-NEXT:    retq

More information about the llvm-commits mailing list