[llvm] aca8b9d - [DAG] SimplifyDemandedBits - if we're only demanding the signbits, a MIN/MAX node can be simplified to an OR or AND node
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 1 02:59:33 PDT 2023
Author: Simon Pilgrim
Date: 2023-09-01T10:56:32+01:00
New Revision: aca8b9d0d56eea757b6903f75ff529f545ab229d
URL: https://github.com/llvm/llvm-project/commit/aca8b9d0d56eea757b6903f75ff529f545ab229d
DIFF: https://github.com/llvm/llvm-project/commit/aca8b9d0d56eea757b6903f75ff529f545ab229d.diff
LOG: [DAG] SimplifyDemandedBits - if we're only demanding the signbits, a MIN/MAX node can be simplified to an OR or AND node
Extension of the signbit-only case: if the sign bits extend down through all of the demanded bits, then SMIN/SMAX/UMIN/UMAX nodes can be simplified to OR/AND/AND/OR respectively.
Alive2: https://alive2.llvm.org/ce/z/mFVFAn (general case)
Differential Revision: https://reviews.llvm.org/D158364
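As a quick, standalone sanity check of the reasoning above (not part of the patch), the sketch below brute-forces the identity on all i8 pairs. numSignBits() is a hypothetical stand-in for the DAG's ComputeNumSignBits(), and the demanded bits are assumed to be exactly the common top sign-bit region:

// Standalone sanity check: if both operands have at least K sign bits and we
// only look at those top K bits, then SMIN -> OR, SMAX -> AND, UMIN -> AND,
// UMAX -> OR.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Count how many leading bits of V are copies of its sign bit (i8 analogue of
// ComputeNumSignBits).
static unsigned numSignBits(uint8_t V) {
  unsigned N = 1;
  while (N < 8 && ((V >> (7 - N)) & 1) == ((V >> 7) & 1))
    ++N;
  return N;
}

int main() {
  for (int A = 0; A < 256; ++A) {
    for (int B = 0; B < 256; ++B) {
      uint8_t X = (uint8_t)A, Y = (uint8_t)B;
      unsigned K = std::min(numSignBits(X), numSignBits(Y));
      uint8_t TopMask = (uint8_t)(0xFF << (8 - K)); // the demanded sign bits
      int8_t SX = (int8_t)X, SY = (int8_t)Y;
      // SMIN -> OR, SMAX -> AND on the demanded bits.
      assert(((uint8_t)std::min(SX, SY) & TopMask) == ((X | Y) & TopMask));
      assert(((uint8_t)std::max(SX, SY) & TopMask) == ((X & Y) & TopMask));
      // UMIN -> AND, UMAX -> OR on the demanded bits.
      assert((std::min(X, Y) & TopMask) == ((X & Y) & TopMask));
      assert((std::max(X, Y) & TopMask) == ((X | Y) & TopMask));
    }
  }
  puts("all 65536 pairs verified");
  return 0;
}

The Alive2 link above covers the general case; this sketch just mirrors it at i8 width, which is enough to see why the fold only fires when NumSignBits covers every demanded bit.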
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/test/CodeGen/X86/known-signbits-vector.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 8d66c9f317e1ec..a8e15b2b1760f3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2156,54 +2156,49 @@ bool TargetLowering::SimplifyDemandedBits(
}
break;
}
- case ISD::SMIN: {
- SDValue Op0 = Op.getOperand(0);
- SDValue Op1 = Op.getOperand(1);
- // If we're only wanting the signbit, then we can simplify to OR node.
- // TODO: Extend this based on ComputeNumSignBits.
- if (DemandedBits.isSignMask())
- return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
- break;
- }
- case ISD::SMAX: {
- SDValue Op0 = Op.getOperand(0);
- SDValue Op1 = Op.getOperand(1);
- // If we're only wanting the signbit, then we can simplify to AND node.
- // TODO: Extend this based on ComputeNumSignBits.
- if (DemandedBits.isSignMask())
- return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, Op1));
- break;
- }
- case ISD::UMIN: {
- SDValue Op0 = Op.getOperand(0);
- SDValue Op1 = Op.getOperand(1);
- // If we're only wanting the msb, then we can simplify to AND node.
- if (DemandedBits.isSignMask())
- return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, Op1));
- // Check if one arg is always less than (or equal) to the other arg.
- KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
- KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
- Known = KnownBits::umin(Known0, Known1);
- if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
- return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
- if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
- return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
- break;
- }
+ case ISD::SMIN:
+ case ISD::SMAX:
+ case ISD::UMIN:
case ISD::UMAX: {
+ unsigned Opc = Op.getOpcode();
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
- // If we're only wanting the msb, then we can simplify to OR node.
- if (DemandedBits.isSignMask())
- return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
- // Check if one arg is always greater than (or equal) to the other arg.
+
+ // If we're only demanding signbits, then we can simplify to OR/AND node.
+ unsigned BitOp =
+ (Opc == ISD::SMIN || Opc == ISD::UMAX) ? ISD::OR : ISD::AND;
+ unsigned NumSignBits =
+ std::min(TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1),
+ TLO.DAG.ComputeNumSignBits(Op1, DemandedElts, Depth + 1));
+ unsigned NumDemandedUpperBits = BitWidth - DemandedBits.countr_zero();
+ if (NumSignBits >= NumDemandedUpperBits)
+ return TLO.CombineTo(Op, TLO.DAG.getNode(BitOp, SDLoc(Op), VT, Op0, Op1));
+
+ // Check if one arg is always less/greater than (or equal) to the other arg.
KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
- Known = KnownBits::umax(Known0, Known1);
- if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
- return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
- if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
- return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
+ switch (Opc) {
+ case ISD::SMIN:
+ // TODO: Add KnownBits::sle/slt handling.
+ break;
+ case ISD::SMAX:
+ // TODO: Add KnownBits::sge/sgt handling.
+ break;
+ case ISD::UMIN:
+ if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
+ return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
+ if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
+ return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
+ Known = KnownBits::umin(Known0, Known1);
+ break;
+ case ISD::UMAX:
+ if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
+ return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
+ if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
+ return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
+ Known = KnownBits::umax(Known0, Known1);
+ break;
+ }
break;
}
case ISD::BITREVERSE: {
diff --git a/llvm/test/CodeGen/X86/known-signbits-vector.ll b/llvm/test/CodeGen/X86/known-signbits-vector.ll
index de7186584e67a3..e500801b69c4d0 100644
--- a/llvm/test/CodeGen/X86/known-signbits-vector.ll
+++ b/llvm/test/CodeGen/X86/known-signbits-vector.ll
@@ -483,28 +483,24 @@ define <4 x float> @signbits_ashr_sext_select_shuffle_sitofp(<4 x i64> %a0, <4 x
define <4 x i32> @signbits_mask_ashr_smax(<4 x i32> %a0, <4 x i32> %a1) {
; X86-LABEL: signbits_mask_ashr_smax:
; X86: # %bb.0:
-; X86-NEXT: vpsrad $25, %xmm0, %xmm0
-; X86-NEXT: vpsrad $25, %xmm1, %xmm1
-; X86-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
+; X86-NEXT: vpand %xmm1, %xmm0, %xmm0
; X86-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT: vpsrad $25, %xmm0, %xmm0
; X86-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
; X86-NEXT: retl
;
; X64-AVX1-LABEL: signbits_mask_ashr_smax:
; X64-AVX1: # %bb.0:
-; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT: vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT: vpand %xmm1, %xmm0, %xmm0
; X64-AVX1-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
; X64-AVX1-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; X64-AVX1-NEXT: retq
;
; X64-AVX2-LABEL: signbits_mask_ashr_smax:
; X64-AVX2: # %bb.0:
-; X64-AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT: vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT: vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT: vpand %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT: vpsrad $25, %xmm0, %xmm0
; X64-AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
; X64-AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; X64-AVX2-NEXT: retq
@@ -521,28 +517,24 @@ declare <4 x i32> @llvm.smax.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
define <4 x i32> @signbits_mask_ashr_smin(<4 x i32> %a0, <4 x i32> %a1) {
; X86-LABEL: signbits_mask_ashr_smin:
; X86: # %bb.0:
-; X86-NEXT: vpsrad $25, %xmm0, %xmm0
-; X86-NEXT: vpsrad $25, %xmm1, %xmm1
-; X86-NEXT: vpminsd %xmm1, %xmm0, %xmm0
+; X86-NEXT: vpor %xmm1, %xmm0, %xmm0
; X86-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT: vpsrad $25, %xmm0, %xmm0
; X86-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
; X86-NEXT: retl
;
; X64-AVX1-LABEL: signbits_mask_ashr_smin:
; X64-AVX1: # %bb.0:
-; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT: vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT: vpminsd %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT: vpor %xmm1, %xmm0, %xmm0
; X64-AVX1-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
; X64-AVX1-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; X64-AVX1-NEXT: retq
;
; X64-AVX2-LABEL: signbits_mask_ashr_smin:
; X64-AVX2: # %bb.0:
-; X64-AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT: vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT: vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT: vpminsd %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT: vpor %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT: vpsrad $25, %xmm0, %xmm0
; X64-AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
; X64-AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; X64-AVX2-NEXT: retq
@@ -559,28 +551,24 @@ declare <4 x i32> @llvm.smin.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
define <4 x i32> @signbits_mask_ashr_umax(<4 x i32> %a0, <4 x i32> %a1) {
; X86-LABEL: signbits_mask_ashr_umax:
; X86: # %bb.0:
-; X86-NEXT: vpsrad $25, %xmm0, %xmm0
-; X86-NEXT: vpsrad $25, %xmm1, %xmm1
-; X86-NEXT: vpmaxud %xmm1, %xmm0, %xmm0
+; X86-NEXT: vpor %xmm1, %xmm0, %xmm0
; X86-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT: vpsrad $25, %xmm0, %xmm0
; X86-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
; X86-NEXT: retl
;
; X64-AVX1-LABEL: signbits_mask_ashr_umax:
; X64-AVX1: # %bb.0:
-; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT: vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT: vpmaxud %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT: vpor %xmm1, %xmm0, %xmm0
; X64-AVX1-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
; X64-AVX1-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; X64-AVX1-NEXT: retq
;
; X64-AVX2-LABEL: signbits_mask_ashr_umax:
; X64-AVX2: # %bb.0:
-; X64-AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT: vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT: vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT: vpmaxud %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT: vpor %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT: vpsrad $25, %xmm0, %xmm0
; X64-AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
; X64-AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; X64-AVX2-NEXT: retq
@@ -597,28 +585,24 @@ declare <4 x i32> @llvm.umax.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
define <4 x i32> @signbits_mask_ashr_umin(<4 x i32> %a0, <4 x i32> %a1) {
; X86-LABEL: signbits_mask_ashr_umin:
; X86: # %bb.0:
-; X86-NEXT: vpsrad $25, %xmm0, %xmm0
-; X86-NEXT: vpsrad $25, %xmm1, %xmm1
-; X86-NEXT: vpminud %xmm1, %xmm0, %xmm0
+; X86-NEXT: vpand %xmm1, %xmm0, %xmm0
; X86-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT: vpsrad $25, %xmm0, %xmm0
; X86-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
; X86-NEXT: retl
;
; X64-AVX1-LABEL: signbits_mask_ashr_umin:
; X64-AVX1: # %bb.0:
-; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT: vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT: vpminud %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT: vpand %xmm1, %xmm0, %xmm0
; X64-AVX1-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
; X64-AVX1-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; X64-AVX1-NEXT: retq
;
; X64-AVX2-LABEL: signbits_mask_ashr_umin:
; X64-AVX2: # %bb.0:
-; X64-AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT: vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT: vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT: vpminud %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT: vpand %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT: vpsrad $25, %xmm0, %xmm0
; X64-AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
; X64-AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; X64-AVX2-NEXT: retq