[llvm] [SelectionDAG] Return std::optional<unsigned> from getValidShiftAmount and friends. NFC (PR #156224)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Aug 30 21:29:36 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-x86
Author: Craig Topper (topperc)
<details>
<summary>Changes</summary>
Return std::optional<unsigned> instead of std::optional<uint64_t>. A valid shift amount must be less than or equal to our maximum supported bit width, and that width fits in unsigned, so the narrower type is always sufficient. Most of the callers already assumed the value fit in unsigned.
---
Full diff: https://github.com/llvm/llvm-project/pull/156224.diff
4 Files Affected:
- (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+6-6)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+13-13)
- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+14-14)
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+4-4)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index dc00db9daa3b6..8a834315646a1 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2350,35 +2350,35 @@ class SelectionDAG {
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
/// that is less than the element bit-width of the shift node, return it.
- LLVM_ABI std::optional<uint64_t>
+ LLVM_ABI std::optional<unsigned>
getValidShiftAmount(SDValue V, const APInt &DemandedElts,
unsigned Depth = 0) const;
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
/// that is less than the element bit-width of the shift node, return it.
- LLVM_ABI std::optional<uint64_t>
+ LLVM_ABI std::optional<unsigned>
getValidShiftAmount(SDValue V, unsigned Depth = 0) const;
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
/// element bit-width of the shift node, return the minimum possible value.
- LLVM_ABI std::optional<uint64_t>
+ LLVM_ABI std::optional<unsigned>
getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
unsigned Depth = 0) const;
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
/// element bit-width of the shift node, return the minimum possible value.
- LLVM_ABI std::optional<uint64_t>
+ LLVM_ABI std::optional<unsigned>
getValidMinimumShiftAmount(SDValue V, unsigned Depth = 0) const;
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
/// element bit-width of the shift node, return the maximum possible value.
- LLVM_ABI std::optional<uint64_t>
+ LLVM_ABI std::optional<unsigned>
getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
unsigned Depth = 0) const;
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
/// element bit-width of the shift node, return the maximum possible value.
- LLVM_ABI std::optional<uint64_t>
+ LLVM_ABI std::optional<unsigned>
getValidMaximumShiftAmount(SDValue V, unsigned Depth = 0) const;
/// Match a binop + shuffle pyramid that represents a horizontal reduction
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 967306ae37f45..56f914907085b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3299,7 +3299,7 @@ SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
return std::nullopt;
}
-std::optional<uint64_t>
+std::optional<unsigned>
SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
@@ -3312,7 +3312,7 @@ SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
return std::nullopt;
}
-std::optional<uint64_t>
+std::optional<unsigned>
SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
@@ -3321,7 +3321,7 @@ SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
return getValidShiftAmount(V, DemandedElts, Depth);
}
-std::optional<uint64_t>
+std::optional<unsigned>
SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
@@ -3333,7 +3333,7 @@ SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
return std::nullopt;
}
-std::optional<uint64_t>
+std::optional<unsigned>
SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
@@ -3342,7 +3342,7 @@ SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
return getValidMinimumShiftAmount(V, DemandedElts, Depth);
}
-std::optional<uint64_t>
+std::optional<unsigned>
SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
@@ -3354,7 +3354,7 @@ SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
return std::nullopt;
}
-std::optional<uint64_t>
+std::optional<unsigned>
SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
@@ -3828,7 +3828,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
// Minimum shift low bits are known zero.
- if (std::optional<uint64_t> ShMinAmt =
+ if (std::optional<unsigned> ShMinAmt =
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
Known.Zero.setLowBits(*ShMinAmt);
break;
@@ -3840,7 +3840,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Op->getFlags().hasExact());
// Minimum shift high bits are known zero.
- if (std::optional<uint64_t> ShMinAmt =
+ if (std::optional<unsigned> ShMinAmt =
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
Known.Zero.setHighBits(*ShMinAmt);
break;
@@ -4871,15 +4871,15 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRA:
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
// SRA X, C -> adds C sign bits.
- if (std::optional<uint64_t> ShAmt =
+ if (std::optional<unsigned> ShAmt =
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
- Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
+ Tmp = std::min(Tmp + *ShAmt, VTBits);
return Tmp;
case ISD::SHL:
if (std::optional<ConstantRange> ShAmtRange =
getValidShiftAmountRange(Op, DemandedElts, Depth + 1)) {
- uint64_t MaxShAmt = ShAmtRange->getUnsignedMax().getZExtValue();
- uint64_t MinShAmt = ShAmtRange->getUnsignedMin().getZExtValue();
+ unsigned MaxShAmt = ShAmtRange->getUnsignedMax().getZExtValue();
+ unsigned MinShAmt = ShAmtRange->getUnsignedMin().getZExtValue();
// Try to look through ZERO/SIGN/ANY_EXTEND. If all extended bits are
// shifted out, then we can compute the number of sign bits for the
// operand being extended. A future improvement could be to pass along the
@@ -4890,7 +4890,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
EVT ExtVT = Ext.getValueType();
SDValue Extendee = Ext.getOperand(0);
EVT ExtendeeVT = Extendee.getValueType();
- uint64_t SizeDifference =
+ unsigned SizeDifference =
ExtVT.getScalarSizeInBits() - ExtendeeVT.getScalarSizeInBits();
if (SizeDifference <= MinShAmt) {
Tmp = SizeDifference +
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 60c4bb574d4bb..6c86706a008ec 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -832,7 +832,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
case ISD::SHL: {
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
+ if (std::optional<unsigned> MaxSA =
DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
SDValue Op0 = Op.getOperand(0);
unsigned ShAmt = *MaxSA;
@@ -847,7 +847,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
case ISD::SRL: {
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
+ if (std::optional<unsigned> MaxSA =
DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
SDValue Op0 = Op.getOperand(0);
unsigned ShAmt = *MaxSA;
@@ -1780,7 +1780,7 @@ bool TargetLowering::SimplifyDemandedBits(
SDValue Op1 = Op.getOperand(1);
EVT ShiftVT = Op1.getValueType();
- if (std::optional<uint64_t> KnownSA =
+ if (std::optional<unsigned> KnownSA =
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *KnownSA;
if (ShAmt == 0)
@@ -1792,7 +1792,7 @@ bool TargetLowering::SimplifyDemandedBits(
// TODO - support non-uniform vector amounts.
if (Op0.getOpcode() == ISD::SRL) {
if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
- if (std::optional<uint64_t> InnerSA =
+ if (std::optional<unsigned> InnerSA =
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
unsigned C1 = *InnerSA;
unsigned Opc = ISD::SHL;
@@ -1832,7 +1832,7 @@ bool TargetLowering::SimplifyDemandedBits(
// TODO - support non-uniform vector amounts.
if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
InnerOp.hasOneUse()) {
- if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+ if (std::optional<unsigned> SA2 = TLO.DAG.getValidShiftAmount(
InnerOp, DemandedElts, Depth + 2)) {
unsigned InnerShAmt = *SA2;
if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
@@ -1949,7 +1949,7 @@ bool TargetLowering::SimplifyDemandedBits(
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
+ if (std::optional<unsigned> MaxSA =
TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *MaxSA;
unsigned NumSignBits =
@@ -1965,7 +1965,7 @@ bool TargetLowering::SimplifyDemandedBits(
SDValue Op1 = Op.getOperand(1);
EVT ShiftVT = Op1.getValueType();
- if (std::optional<uint64_t> KnownSA =
+ if (std::optional<unsigned> KnownSA =
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *KnownSA;
if (ShAmt == 0)
@@ -1977,7 +1977,7 @@ bool TargetLowering::SimplifyDemandedBits(
// TODO - support non-uniform vector amounts.
if (Op0.getOpcode() == ISD::SHL) {
if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
- if (std::optional<uint64_t> InnerSA =
+ if (std::optional<unsigned> InnerSA =
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
unsigned C1 = *InnerSA;
unsigned Opc = ISD::SRL;
@@ -1997,7 +1997,7 @@ bool TargetLowering::SimplifyDemandedBits(
// single sra. We can do this if the top bits are never demanded.
if (Op0.getOpcode() == ISD::SRA && Op0.hasOneUse()) {
if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
- if (std::optional<uint64_t> InnerSA =
+ if (std::optional<unsigned> InnerSA =
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
unsigned C1 = *InnerSA;
// Clamp the combined shift amount if it exceeds the bit width.
@@ -2062,7 +2062,7 @@ bool TargetLowering::SimplifyDemandedBits(
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
+ if (std::optional<unsigned> MaxSA =
TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *MaxSA;
// Must already be signbits in DemandedBits bounds, and can't demand any
@@ -2101,7 +2101,7 @@ bool TargetLowering::SimplifyDemandedBits(
if (DemandedBits.isOne())
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
- if (std::optional<uint64_t> KnownSA =
+ if (std::optional<unsigned> KnownSA =
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *KnownSA;
if (ShAmt == 0)
@@ -2110,7 +2110,7 @@ bool TargetLowering::SimplifyDemandedBits(
// fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
// supports sext_inreg.
if (Op0.getOpcode() == ISD::SHL) {
- if (std::optional<uint64_t> InnerSA =
+ if (std::optional<unsigned> InnerSA =
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
unsigned LowBits = BitWidth - ShAmt;
EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
@@ -2657,11 +2657,11 @@ bool TargetLowering::SimplifyDemandedBits(
break;
}
- std::optional<uint64_t> ShAmtC =
+ std::optional<unsigned> ShAmtC =
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
if (!ShAmtC || *ShAmtC >= BitWidth)
break;
- uint64_t ShVal = *ShAmtC;
+ unsigned ShVal = *ShAmtC;
APInt HighBits =
APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index d78cf00a5a2fc..08ae0d52d795e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21252,7 +21252,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
// the truncation then we can use PACKSS by converting the srl to a sra.
// SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
if (In.getOpcode() == ISD::SRL && In->hasOneUse())
- if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
+ if (std::optional<unsigned> ShAmt = DAG.getValidShiftAmount(In)) {
if (*ShAmt == MinSignBits) {
PackOpcode = X86ISD::PACKSS;
return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
@@ -48383,7 +48383,7 @@ static SDValue checkSignTestSetCCCombine(SDValue Cmp, X86::CondCode &CC,
// If Src came from a SHL (probably from an expanded SIGN_EXTEND_INREG), then
// peek through and adjust the TEST bit.
if (Src.getOpcode() == ISD::SHL) {
- if (std::optional<uint64_t> ShiftAmt = DAG.getValidShiftAmount(Src)) {
+ if (std::optional<unsigned> ShiftAmt = DAG.getValidShiftAmount(Src)) {
Src = Src.getOperand(0);
BitMask.lshrInPlace(*ShiftAmt);
}
@@ -54169,10 +54169,10 @@ static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
static SDValue combinei64TruncSrlConstant(SDValue N, EVT VT, SelectionDAG &DAG,
const SDLoc &DL) {
assert(N.getOpcode() == ISD::SRL && "Unknown shift opcode");
- std::optional<uint64_t> ValidSrlConst = DAG.getValidShiftAmount(N);
+ std::optional<unsigned> ValidSrlConst = DAG.getValidShiftAmount(N);
if (!ValidSrlConst)
return SDValue();
- uint64_t SrlConstVal = *ValidSrlConst;
+ unsigned SrlConstVal = *ValidSrlConst;
SDValue Op = N.getOperand(0);
unsigned Opcode = Op.getOpcode();
``````````
</details>
https://github.com/llvm/llvm-project/pull/156224
More information about the llvm-commits mailing list